SAC

mighty.mighty_agents.sac #

MightySACAgent #

MightySACAgent(
    output_dir: Path,
    env: MIGHTYENV,
    eval_env: Optional[MIGHTYENV] = None,
    seed: Optional[int] = None,
    n_policy_units: int = 64,
    soft_update_weight: float = 0.005,
    batch_size: int = 256,
    learning_starts: int = 10000,
    update_every: int = 50,
    n_gradient_steps: int = 1,
    policy_lr: float = 0.0003,
    q_lr: float = 0.0003,
    gamma: float = 0.99,
    alpha: float = 0.2,
    auto_alpha: bool = True,
    target_entropy: Optional[float] = None,
    alpha_lr: float = 0.0003,
    hidden_sizes: Optional[List[int]] = None,
    activation: str = "relu",
    log_std_min: float = -5,
    log_std_max: float = 2,
    render_progress: bool = True,
    log_wandb: bool = False,
    wandb_kwargs: Optional[Dict] = None,
    replay_buffer_class: Type[MightyReplay] = MightyReplay,
    replay_buffer_kwargs: Optional[TypeKwargs] = None,
    meta_methods: Optional[List[Union[str, type]]] = None,
    meta_kwargs: Optional[List[TypeKwargs]] = None,
    policy_class: Optional[
        Union[
            str, DictConfig, Type[MightyExplorationPolicy]
        ]
    ] = None,
    policy_kwargs: Optional[Dict] = None,
    normalize_obs: bool = True,
    normalize_reward: bool = True,
    rescale_action: bool = False,
    policy_frequency: int = 2,
    target_network_frequency: int = 1,
)

Bases: MightyAgent
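
A minimal construction sketch (not taken from the Mighty source): it assumes a recent Gymnasium with gymnasium.make_vec and that a continuous-control vector environment is a valid MIGHTYENV; the environment name and output directory are illustrative, and all omitted arguments fall back to the defaults shown above.

import gymnasium as gym
from pathlib import Path

from mighty.mighty_agents.sac import MightySACAgent

# Hypothetical setup: any continuous-action vectorized env should fit the signature above
env = gym.make_vec("Pendulum-v1", num_envs=4)
eval_env = gym.make_vec("Pendulum-v1", num_envs=4)

agent = MightySACAgent(
    output_dir=Path("results/sac_pendulum"),
    env=env,
    eval_env=eval_env,
    seed=0,
    batch_size=256,
    learning_starts=10_000,
)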

Source code in mighty/mighty_agents/sac.py
def __init__(
    self,
    output_dir: Path,
    env: MIGHTYENV,
    eval_env: Optional[MIGHTYENV] = None,
    seed: Optional[int] = None,
    n_policy_units: int = 64,
    soft_update_weight: float = 0.005,
    # --- Replay & update scheduling ---
    batch_size: int = 256,
    learning_starts: int = 10000,
    update_every: int = 50,
    n_gradient_steps: int = 1,
    # --- Learning rates ---
    policy_lr: float = 3e-4,
    q_lr: float = 3e-4,
    # --- SAC hyperparameters ---
    gamma: float = 0.99,
    alpha: float = 0.2,
    auto_alpha: bool = True,
    target_entropy: Optional[float] = None,
    alpha_lr: float = 3e-4,
    # --- Network architecture (optional override) ---
    hidden_sizes: Optional[List[int]] = None,
    activation: str = "relu",
    log_std_min: float = -5,
    log_std_max: float = 2,
    # --- Logging & buffer ---
    render_progress: bool = True,
    log_wandb: bool = False,
    wandb_kwargs: Optional[Dict] = None,
    replay_buffer_class: Type[MightyReplay] = MightyReplay,
    replay_buffer_kwargs: Optional[TypeKwargs] = None,
    meta_methods: Optional[List[Union[str, type]]] = None,
    meta_kwargs: Optional[List[TypeKwargs]] = None,
    policy_class: Optional[
        Union[str, DictConfig, Type[MightyExplorationPolicy]]
    ] = None,
    policy_kwargs: Optional[Dict] = None,
    normalize_obs: bool = True,  # Whether to normalize observations
    normalize_reward: bool = True,  # Whether to normalize rewards
    rescale_action: bool = False,  # Whether to rescale actions to the environment's action space
    policy_frequency: int = 2,  # Frequency of policy updates
    target_network_frequency: int = 1,  # Frequency of target network updates
):
    """Initialize SAC agent with tunable hyperparameters and backward-compatible names."""
    if hidden_sizes is None:
        hidden_sizes = [n_policy_units, n_policy_units]
    tau = soft_update_weight

    # Save hyperparameters
    self.batch_size = batch_size
    self.learning_starts = learning_starts
    self.update_every = update_every
    self.n_gradient_steps = n_gradient_steps
    self.policy_lr = policy_lr
    self.q_lr = q_lr
    self.gamma = gamma
    self.tau = tau
    self.alpha = alpha
    self.hidden_sizes = hidden_sizes
    self.activation = activation
    self.log_std_min = log_std_min
    self.log_std_max = log_std_max

    self.auto_alpha = auto_alpha
    self.target_entropy = target_entropy
    self.alpha_lr = alpha_lr

    # Placeholders for model and updater
    self.model: SACModel | None = None
    self.update_fn: SACUpdate | None = None

    # Exploration policy class
    self.policy_class = retrieve_class(
        cls=policy_class, default_cls=StochasticPolicy
    )
    self.policy_kwargs = policy_kwargs or {
        "discrete": False  # Default to continuous SAC
    }

    self.policy_frequency = policy_frequency
    self.target_network_frequency = target_network_frequency

    super().__init__(
        env=env,
        output_dir=output_dir,
        seed=seed,
        eval_env=eval_env,
        learning_starts=learning_starts,
        n_gradient_steps=n_gradient_steps,
        render_progress=render_progress,
        log_wandb=log_wandb,
        wandb_kwargs=wandb_kwargs,
        replay_buffer_class=replay_buffer_class,
        replay_buffer_kwargs=replay_buffer_kwargs,
        meta_methods=meta_methods,
        meta_kwargs=meta_kwargs,
        normalize_obs=normalize_obs,
        normalize_reward=normalize_reward,
        rescale_action=rescale_action,
        batch_size=batch_size,
        learning_rate=policy_lr,  # For compatibility with base class
    )

    # Initialize loss buffer for logging
    self.loss_buffer = {
        "Update/q_loss1": [],
        "Update/q_loss2": [],
        "Update/policy_loss": [],
        "Update/td_error1": [],
        "Update/td_error2": [],
        "step": [],
    }

parameters property #

parameters: List[Parameter]

Collect policy + Q-network parameters for SAC.

value_function property #

value_function: Module

Value function for compatibility: V(s) = min(Q1,Q2)(s, a_policy) - alpha * log_pi(a|s).
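
The docstring above spells out the soft value used for compatibility. As a hedged, self-contained sketch of that formula (the tensor names q1, q2, log_prob and the helper soft_value are illustrative, not part of the Mighty API):

import torch

def soft_value(q1: torch.Tensor, q2: torch.Tensor, log_prob: torch.Tensor, alpha: float) -> torch.Tensor:
    # V(s) = min(Q1, Q2)(s, a_policy) - alpha * log_pi(a|s)
    # q1, q2: Q-values of an action sampled from the current policy
    # log_prob: log-probability of that action under the policy
    return torch.min(q1, q2) - alpha * log_prob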

__del__ #

__del__() -> None

Close wandb upon deletion.

Source code in mighty/mighty_agents/base_agent.py
def __del__(self) -> None:
    """Close wandb upon deletion."""
    self.env.close()  # type: ignore
    if self.log_wandb:
        wandb.finish()

adapt_hps #

adapt_hps(metrics: Dict) -> None

Set hyperparameters.

Source code in mighty/mighty_agents/base_agent.py
def adapt_hps(self, metrics: Dict) -> None:
    """Set hyperparameters."""
    old_hps = {
        "step": self.steps,
        "hp/lr": self.learning_rate,
        "hp/pi_epsilon": self._epsilon,
        "hp/batch_size": self._batch_size,
        "hp/learning_starts": self._learning_starts,
        "meta_modules": list(self.meta_modules.keys()),
    }
    self.learning_rate = metrics["hp/lr"]
    self._epsilon = metrics["hp/pi_epsilon"]
    self._batch_size = metrics["hp/batch_size"]
    self._learning_starts = metrics["hp/learning_starts"]

    updated_hps = {
        "step": self.steps,
        "hp/lr": self.learning_rate,
        "hp/pi_epsilon": self._epsilon,
        "hp/batch_size": self._batch_size,
        "hp/learning_starts": self._learning_starts,
        "meta_modules": list(self.meta_modules.keys()),
    }

    if any(old_hps[k] != updated_hps[k] for k in old_hps.keys()):
        self.hp_buffer = update_buffer(self.hp_buffer, updated_hps)
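
Since adapt_hps indexes the metrics dictionary directly, all four "hp/" keys must be present. A usage sketch (the agent variable and the concrete values are illustrative):

metrics = {
    "hp/lr": 1e-4,
    "hp/pi_epsilon": 0.1,
    "hp/batch_size": 128,
    "hp/learning_starts": 5_000,
}
agent.adapt_hps(metrics)  # appends to hp_buffer only if a value actually changed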

apply_config #

apply_config(config: Dict) -> None

Apply config to agent.

Source code in mighty/mighty_agents/base_agent.py
def apply_config(self, config: Dict) -> None:
    """Apply config to agent."""
    for n in config:
        algo_name = n.split(".")[-1]
        if hasattr(self, algo_name):
            setattr(self, algo_name, config[n])
        elif hasattr(self, "_" + algo_name):
            setattr(self, "_" + algo_name, config[n])
        elif n in ["architecture", "n_units", "n_layers", "size"]:
            pass
        else:
            print(f"Trying to set hyperparameter {algo_name} which does not exist.")

evaluate #

evaluate(eval_env: MIGHTYENV | None = None) -> Dict

Evaluate the agent on an environment using full rollouts.

:param eval_env: The environment to evaluate on; defaults to the agent's eval_env.
:return: Dictionary of evaluation metrics.

Source code in mighty/mighty_agents/base_agent.py
def evaluate(self, eval_env: MIGHTYENV | None = None) -> Dict:  # type: ignore
    """Evaluate the agent on an environment using full rollouts.

    :param eval_env: The environment to evaluate on; defaults to the agent's eval_env.
    :return: Dictionary of evaluation metrics.
    """

    terminated, truncated = False, False
    options: Dict = {}
    if eval_env is None:
        eval_env = self.eval_env

    state, _ = eval_env.reset(options=options, seed=self.seed)  # type: ignore
    rewards = np.zeros(eval_env.num_envs)  # type: ignore
    steps = np.zeros(eval_env.num_envs)  # type: ignore
    mask = np.zeros(eval_env.num_envs)  # type: ignore
    while not np.all(mask):
        action = self.policy(state, evaluate=True)  # type: ignore
        state, reward, terminated, truncated, _ = eval_env.step(action)  # type: ignore
        rewards += reward * (1 - mask)
        steps += 1 * (1 - mask)
        dones = np.logical_or(terminated, truncated)
        mask = np.where(dones, 1, mask)

    if isinstance(self.eval_env, DACENV) or isinstance(self.env, CARLENV):
        instance = eval_env.instance  # type: ignore
    else:
        instance = "None"

    eval_metrics = {
        "step": self.steps,
        "seed": self.seed,
        "eval_episodes": np.array(rewards) / steps,
        "mean_eval_step_reward": np.mean(rewards) / steps,
        "mean_eval_reward": np.mean(rewards),
        "instance": instance,
        "eval_rewards": rewards,
    }
    self.eval_buffer = update_buffer(self.eval_buffer, eval_metrics)

    if self.log_wandb:
        wandb.log(eval_metrics)

    return eval_metrics
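
A hedged usage sketch (the agent variable is illustrative); the returned dictionary contains the keys built in eval_metrics above:

eval_metrics = agent.evaluate()          # falls back to the agent's own eval_env
print(eval_metrics["mean_eval_reward"])  # mean return over the vectorized eval episodes
print(eval_metrics["eval_rewards"])      # per-environment returns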

initialize_agent #

initialize_agent() -> None

General initialization of tracer and buffer for all agents.

Algorithm specific initialization like policies etc. are done in _initialize_agent

Source code in mighty/mighty_agents/base_agent.py
def initialize_agent(self) -> None:
    """General initialization of tracer and buffer for all agents.

    Algorithm specific initialization like policies etc.
    are done in _initialize_agent
    """
    self._initialize_agent()

    if isinstance(self.buffer_class, type) and issubclass(
        self.buffer_class, PrioritizedReplay
    ):
        if isinstance(self.buffer_kwargs, DictConfig):
            self.buffer_kwargs = OmegaConf.to_container(
                self.buffer_kwargs, resolve=True
            )
        # 1) Get observation-space shape
        try:
            obs_space = self.env.single_observation_space
            obs_shape = tuple(obs_space.shape)
        except Exception:
            # Fallback: call env.reset() once and infer shape from returned numpy/torch array
            first_obs, _ = self.env.reset(seed=self.seed)
            obs_shape = tuple(np.array(first_obs).shape)

        # 2) Get action-space shape (if discrete, .n is number of actions)
        action_space = self.env.single_action_space
        if hasattr(action_space, "n"):
            # Discrete action space → action_shape = () (scalar), but Q-net will expect a single integer
            # We store it as a zero-length tuple, and treat it as int later.
            action_shape = ()
        else:
            # Continuous action space, e.g. Box(shape=(3,)), so we store that tuple
            action_shape = tuple(action_space.shape)

        # 3) Overwrite the YAML placeholders (null → actual)
        self.buffer_kwargs["obs_shape"] = obs_shape
        self.buffer_kwargs["action_shape"] = action_shape

    self.buffer = self.buffer_class(**self.buffer_kwargs)  # type: ignore
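
When a PrioritizedReplay buffer is configured, obs_shape and action_shape may be left as placeholders (e.g. null in a YAML config); initialize_agent fills them in from the environment's spaces as shown above. A sketch under that assumption, reusing env from the construction sketch near the top of this page (the import path and the remaining PrioritizedReplay constructor arguments are not verified here):

from mighty.mighty_replay import PrioritizedReplay  # import path assumed

agent = MightySACAgent(
    output_dir=Path("results/sac_per"),
    env=env,
    replay_buffer_class=PrioritizedReplay,
    # Placeholders are overwritten with the actual shapes in initialize_agent()
    replay_buffer_kwargs={"obs_shape": None, "action_shape": None},
)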

make_checkpoint_dir #

make_checkpoint_dir(t: int) -> None

Create the checkpoint directory for the given timestep.

:param t: Current timestep
:return:

Source code in mighty/mighty_agents/base_agent.py
def make_checkpoint_dir(self, t: int) -> None:
    """Create the checkpoint directory for the given timestep.

    :param t: Current timestep
    :return:
    """
    self.upper_checkpoint_dir = Path(self.output_dir) / Path("checkpoints")
    if not self.upper_checkpoint_dir.exists():
        Path(self.upper_checkpoint_dir).mkdir()
    self.checkpoint_dir = self.upper_checkpoint_dir / f"{t}"
    if not self.checkpoint_dir.exists():
        Path(self.checkpoint_dir).mkdir()

run #

run(
    n_steps: int,
    eval_every_n_steps: int = 1000,
    save_model_every_n_steps: int | None = 5000,
    env: MIGHTYENV = None,
) -> Dict

Run agent.

Source code in mighty/mighty_agents/base_agent.py
def run(  # noqa: PLR0915
    self,
    n_steps: int,
    eval_every_n_steps: int = 1_000,
    save_model_every_n_steps: int | None = 5000,
    env: MIGHTYENV = None,  # type: ignore
) -> Dict:
    """Run agent."""
    episodes = 0
    if env is not None:
        self.env = env

    logging_layout, progress, steps_task = self.make_logging_layout(n_steps)
    update_multiplier = 0

    with Live(logging_layout, refresh_per_second=10, vertical_overflow="visible"):
        steps_since_eval = 0
        steps_since_log = 0

        metrics = {
            "env": self.env,
            "vf": self.value_function,  # type: ignore
            "policy": self.policy,
            "step": self.steps,
            "hp/lr": self.learning_rate,
            "hp/pi_epsilon": self._epsilon,
            "hp/batch_size": self._batch_size,
            "hp/learning_starts": self._learning_starts,
        }

        # Reset env and initialize reward sum
        curr_s, _ = self.env.reset(seed=self.seed)  # type: ignore
        if len(curr_s.squeeze().shape) == 0:
            episode_reward = [0]
        else:
            episode_reward = np.zeros(curr_s.squeeze().shape[0])  # type: ignore

        last_episode_reward = episode_reward
        if not torch.is_tensor(last_episode_reward):
            last_episode_reward = torch.tensor(last_episode_reward).float()

        recent_episode_reward = []
        recent_step_reward = []
        recent_actions = []
        evaluation_reward = []

        # Start logging
        eval_curve = [0]
        learning_curve = [0]
        curve_xs = [0]
        progress.update(steps_task, visible=True)
        logging_layout["lower"]["left"].update(
            self.get_plot(curve_xs, learning_curve, "Training Reward")
        )
        logging_layout["lower"]["right"].update(
            self.get_plot(curve_xs, eval_curve, "Evaluation Reward")
        )

        # Main loop: rollouts, training and evaluation
        while self.steps < n_steps:
            metrics["episode_reward"] = episode_reward

            action, log_prob = self.step(curr_s, metrics)
            # step the env as usual
            next_s, reward, terminated, truncated, infos = self.env.step(action)

            # Decide which samples are true "done" for the replay buffer
            replay_dones = terminated  # physics-failure only, not time-limit truncation
            dones = np.logical_or(terminated, truncated)

            # Overwrite next_s on truncation
            # Based on https://github.com/DLR-RM/stable-baselines3/issues/284
            real_next_s = next_s.copy()
            # infos["final_observation"] is a list/array of the last real obs
            for i, tr in enumerate(truncated):
                if tr:
                    real_next_s[i] = infos["final_observation"][i]
            episode_reward += reward

            # Log everything
            t = {
                "seed": self.seed,
                "step": self.steps,
                "reward": reward,
                "action": action,
                "state": curr_s,
                "next_state": real_next_s,
                "terminated": terminated.astype(int),
                "truncated": truncated.astype(int),
                "dones": replay_dones.astype(int),
                "mean_episode_reward": last_episode_reward.mean()
                .cpu()
                .numpy()
                .item(),
            }
            metrics["log_prob"] = log_prob.detach().cpu().numpy()
            metrics["episode_reward"] = episode_reward
            metrics["transition"] = t

            recent_actions.append(np.mean(action))
            if len(recent_actions) > 100:
                recent_actions.pop(0)

            for k in self.meta_modules:
                self.meta_modules[k].post_step(metrics)

            transition_metrics = self.process_transition(
                metrics["transition"]["state"],
                metrics["transition"]["action"],
                metrics["transition"]["reward"],
                metrics["transition"]["next_state"],
                metrics["transition"]["dones"],
                metrics["log_prob"],
                metrics,
            )
            metrics.update(transition_metrics)
            self.result_buffer = update_buffer(self.result_buffer, t)

            if self.log_wandb:
                wandb.log(t)

            self.steps += len(action)
            metrics["step"] = self.steps
            steps_since_eval += len(action)
            steps_since_log += len(action)
            for _ in range(len(action)):
                progress.advance(steps_task)

            # Update agent
            if (
                len(self.buffer) >= self._batch_size  # type: ignore
                and self.steps >= self._learning_starts
            ):
                update_kwargs = {"next_s": next_s, "dones": dones}
                metrics = self.update(metrics, update_kwargs)

            # End step
            self.last_state = curr_s
            curr_s = next_s

            # Evaluate
            if eval_every_n_steps and steps_since_eval >= eval_every_n_steps:
                steps_since_eval = 0
                eval_metrics = self.evaluate()
                evaluation_reward = eval_metrics["eval_rewards"]

            # Log to command line via rich layout
            if self.steps >= 1000 * update_multiplier:
                metrics_table = self.make_logging_table(
                    self.steps,
                    recent_episode_reward,
                    recent_step_reward,
                    evaluation_reward,
                    recent_actions,
                )
                logging_layout["middle"]["left"].update(metrics_table)
                eval_curve.append(np.mean(evaluation_reward))
                learning_curve.append(np.mean(recent_episode_reward))
                curve_xs.append(self.steps)

                logging_layout["lower"]["left"].update(
                    self.get_plot(curve_xs, learning_curve, "Training Reward")
                )
                logging_layout["lower"]["right"].update(
                    self.get_plot(curve_xs, eval_curve, "Evaluation Reward")
                )
                update_multiplier += 1

            # Save model & metrics
            if (
                save_model_every_n_steps
                and steps_since_log >= save_model_every_n_steps
            ):
                steps_since_log = 0
                self.save(self.steps)
                log_to_file(
                    self.output_dir,
                    self.result_buffer,
                    self.hp_buffer,
                    self.eval_buffer,
                    self.loss_buffer,
                )

            # Perform resets as necessary
            if np.any(dones):
                last_episode_reward = np.where(  # type: ignore
                    dones, episode_reward, last_episode_reward
                )
                recent_episode_reward.append(np.mean(last_episode_reward))
                recent_step_reward.append(
                    np.mean(last_episode_reward) / len(last_episode_reward)
                )
                last_episode_reward = torch.tensor(last_episode_reward).float()
                if len(recent_episode_reward) > 10:
                    recent_episode_reward.pop(0)
                    recent_step_reward.pop(0)
                episode_reward = np.where(dones, 0, episode_reward)  # type: ignore
                # End episode
                if isinstance(self.env, DACENV) or isinstance(self.env, CARLENV):
                    instance = self.env.instance  # type: ignore
                else:
                    instance = None
                metrics["instance"] = instance
                episodes += 1
                for k in self.meta_modules:
                    self.meta_modules[k].post_episode(metrics)

                if "rollout_values" in metrics:
                    del metrics["rollout_values"]

                if "rollout_logits" in metrics:
                    del metrics["rollout_logits"]

                # Meta Module hooks
                for k in self.meta_modules:
                    self.meta_modules[k].pre_episode(metrics)

    # Final logging
    log_to_file(
        self.output_dir,
        self.result_buffer,
        self.hp_buffer,
        self.eval_buffer,
        self.loss_buffer,
    )
    return metrics
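
Continuing the hypothetical agent from the construction sketch near the top of this page, a typical training call mirrors the defaults in the signature above:

# Train for 100k environment steps, evaluate every 1,000 steps,
# and save model/metrics every 5,000 steps.
final_metrics = agent.run(
    n_steps=100_000,
    eval_every_n_steps=1_000,
    save_model_every_n_steps=5_000,
)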

update #

update(metrics: Dict, update_kwargs: Dict) -> Dict

Update agent.

Source code in mighty/mighty_agents/base_agent.py
def update(self, metrics: Dict, update_kwargs: Dict) -> Dict:
    """Update agent."""
    for k in self.meta_modules:
        self.meta_modules[k].pre_update(metrics)

    batches = []
    for batches_left in reversed(range(self.n_gradient_steps)):
        batch = self.buffer.sample(self._batch_size)
        agent_update_metrics = self.update_agent(
            transition_batch=batch, batches_left=batches_left, **update_kwargs
        )

        metrics.update(agent_update_metrics)
        metrics["step"] = self.steps

        if self.log_wandb:
            log_to_wandb(metrics=metrics)

        metrics["env"] = self.env
        metrics["vf"] = self.value_function  # type: ignore
        metrics["policy"] = self.policy
        batches.append(batch)

    metrics["update_batches"] = batches
    for k in self.meta_modules:
        self.meta_modules[k].post_update(metrics)
    del metrics["update_batches"]
    return metrics