DQN

mighty.mighty_agents.dqn #

DQN agent.

MightyDQNAgent #

MightyDQNAgent(
    output_dir: str,
    env: MIGHTYENV,
    seed: int | None = None,
    eval_env: MIGHTYENV = None,
    learning_rate: float = 0.01,
    gamma: float = 0.9,
    epsilon: float = 0.1,
    batch_size: int = 64,
    learning_starts: int = 1,
    render_progress: bool = True,
    log_wandb: bool = False,
    wandb_kwargs: dict | None = None,
    replay_buffer_class: str
    | DictConfig
    | type[MightyReplay]
    | None = None,
    replay_buffer_kwargs: TypeKwargs | None = None,
    meta_methods: list[str | type] | None = None,
    meta_kwargs: list[TypeKwargs] | None = None,
    use_target: bool = True,
    n_units: int = 8,
    soft_update_weight: float = 0.01,
    policy_class: str
    | DictConfig
    | type[MightyExplorationPolicy]
    | None = None,
    policy_kwargs: TypeKwargs | None = None,
    q_class: str | DictConfig | type[DQN] | None = None,
    q_kwargs: TypeKwargs | None = None,
    td_update_class: type[QLearning] = QLearning,
    td_update_kwargs: TypeKwargs | None = None,
    save_replay: bool = False,
)

Bases: MightyAgent

Mighty DQN agent.

This agent implements the DQN algorithm and its extensions, as first proposed in "Playing Atari with Deep Reinforcement Learning" by Mnih et al. (2013). The Double DQN (DDQN) variant was proposed by van Hasselt et al. in "Deep Reinforcement Learning with Double Q-learning" (2016). Like all Mighty agents, it is meant to be used via the train method. By default, this agent uses an epsilon-greedy exploration policy.

Creates all relevant class variables and calls agent-specific init function

:param env: Train environment
:param eval_env: Evaluation environment
:param learning_rate: Learning rate for training
:param epsilon: Exploration factor for training
:param batch_size: Batch size for training
:param render_progress: Render progress
:param log_tensorboard: Log to tensorboard as well as to file
:param replay_buffer_class: Replay buffer class from coax replay buffers
:param replay_buffer_kwargs: Arguments for the replay buffer
:param tracer_class: Reward tracing class from coax tracers
:param tracer_kwargs: Arguments for the reward tracer
:param n_units: Number of units for Q network
:param soft_update_weight: Size of soft updates for target network
:param policy_class: Policy class from coax value-based policies
:param policy_kwargs: Arguments for the policy
:param td_update_class: Kind of TD update used from coax TD updates
:param td_update_kwargs: Arguments for the TD update
:return:
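
A minimal usage sketch based on the signature above. The environment construction is an assumption (a Gymnasium vector environment standing in for MIGHTYENV), the hyperparameter values are illustrative, and training is started via the run method documented further down this page.

import gymnasium as gym

from mighty.mighty_agents.dqn import MightyDQNAgent

# Assumption: a vectorized Gymnasium environment; the agent operates on
# batched observations, so a vector env is used here for illustration.
env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
eval_env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])

agent = MightyDQNAgent(
    output_dir="runs/dqn_cartpole",
    env=env,
    eval_env=eval_env,
    seed=0,
    learning_rate=1e-3,
    gamma=0.99,
    epsilon=0.1,
    batch_size=64,
)

# Interleaves rollouts, TD updates and periodic evaluation (see run below).
results = agent.run(n_steps=10_000, eval_every_n_steps=1_000)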

Source code in mighty/mighty_agents/dqn.py
def __init__(
    self,
    output_dir: str,
    # MightyAgent Args
    env: MIGHTYENV,  # type: ignore
    seed: int | None = None,
    eval_env: MIGHTYENV = None,  # type: ignore
    learning_rate: float = 0.01,
    gamma: float = 0.9,
    epsilon: float = 0.1,
    batch_size: int = 64,
    learning_starts: int = 1,
    render_progress: bool = True,
    log_wandb: bool = False,
    wandb_kwargs: dict | None = None,
    replay_buffer_class: str | DictConfig | type[MightyReplay] | None = None,
    replay_buffer_kwargs: TypeKwargs | None = None,
    meta_methods: list[str | type] | None = None,
    meta_kwargs: list[TypeKwargs] | None = None,
    # DDQN Specific Args
    use_target: bool = True,
    n_units: int = 8,
    soft_update_weight: float = 0.01,
    policy_class: str | DictConfig | type[MightyExplorationPolicy] | None = None,
    policy_kwargs: TypeKwargs | None = None,
    q_class: str | DictConfig | type[DQN] | None = None,
    q_kwargs: TypeKwargs | None = None,
    td_update_class: type[QLearning] = QLearning,
    td_update_kwargs: TypeKwargs | None = None,
    save_replay: bool = False,
):
    # FIXME: the arguments are not complete. Double check all classes.
    """DQN initialization.

    Creates all relevant class variables and calls agent-specific init function

    :param env: Train environment
    :param eval_env: Evaluation environment
    :param learning_rate: Learning rate for training
    :param epsilon: Exploration factor for training
    :param batch_size: Batch size for training
    :param render_progress: Render progress
    :param log_tensorboard: Log to tensorboard as well as to file
    :param replay_buffer_class: Replay buffer class from coax replay buffers
    :param replay_buffer_kwargs: Arguments for the replay buffer
    :param tracer_class: Reward tracing class from coax tracers
    :param tracer_kwargs: Arguments for the reward tracer
    :param n_units: Number of units for Q network
    :param soft_update_weight: Size of soft updates for target network
    :param policy_class: Policy class from coax value-based policies
    :param policy_kwargs: Arguments for the policy
    :param td_update_class: Kind of TD update used from coax TD updates
    :param td_update_kwargs: Arguments for the TD update
    :return:
    """
    if meta_kwargs is None:
        meta_kwargs = []
    if meta_methods is None:
        meta_methods = []
    if wandb_kwargs is None:
        wandb_kwargs = {}
    self.n_units = n_units
    assert 0.0 <= soft_update_weight <= 1.0  # noqa: PLR2004
    self.soft_update_weight = soft_update_weight

    # Placeholder variables which are filled in self.initialize_agent
    self.q: DQN | None = None
    self.policy: MightyExplorationPolicy | None = None
    self.q_target: DQN | None = None
    self.qlearning: QLearning | None = None
    self.use_target = use_target

    # Q-function Class
    q_class = retrieve_class(cls=q_class, default_cls=DQN)  # type: ignore
    if q_kwargs is None:
        q_kwargs = {"n_layers": 0}  # type: ignore
    self.q_class = q_class
    self.q_kwargs = q_kwargs

    # Policy Class
    policy_class = retrieve_class(cls=policy_class, default_cls=EpsilonGreedy)  # type: ignore
    if policy_kwargs is None:
        policy_kwargs = {"epsilon": 0.1}  # type: ignore
    self.policy_class = policy_class
    self.policy_kwargs = policy_kwargs

    self.td_update_class = retrieve_class(
        cls=td_update_class, default_cls=DoubleQLearning
    )
    if td_update_kwargs is None:
        td_update_kwargs = {"gamma": gamma}  # type: ignore
    self.td_update_kwargs = td_update_kwargs
    self.save_replay = save_replay

    super().__init__(
        env=env,
        output_dir=output_dir,
        seed=seed,
        eval_env=eval_env,
        learning_rate=learning_rate,
        epsilon=epsilon,
        batch_size=batch_size,
        learning_starts=learning_starts,
        render_progress=render_progress,
        log_wandb=log_wandb,
        wandb_kwargs=wandb_kwargs,
        replay_buffer_class=replay_buffer_class,
        replay_buffer_kwargs=replay_buffer_kwargs,
        meta_methods=meta_methods,
        meta_kwargs=meta_kwargs,
    )

    self.loss_buffer = {
        "Update/loss": [],
        "Update/td_errors": [],
        "step": [],
    }
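
As the constructor shows, q_class, policy_class and td_update_class are resolved through retrieve_class and can be given as a class, a string, or a DictConfig; when omitted, the defaults above are used (a DQN head with {"n_layers": 0}, EpsilonGreedy with {"epsilon": 0.1}). A hedged sketch that only overrides the kwargs and the network width, with the environment setup assumed as in the first example:

import gymnasium as gym

from mighty.mighty_agents.dqn import MightyDQNAgent

env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])

agent = MightyDQNAgent(
    output_dir="runs/dqn_custom",
    env=env,
    n_units=64,                      # width of the Q network
    q_kwargs={"n_layers": 1},        # default is {"n_layers": 0}
    policy_kwargs={"epsilon": 0.2},  # default is {"epsilon": 0.1}
    use_target=True,
    soft_update_weight=0.005,        # tau for the soft target update
)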

parameters property #

parameters: List

Q-function parameters.

value_function property #

value_function: DQN

Q-function.

__del__ #

__del__() -> None

Close wandb upon deletion.

Source code in mighty/mighty_agents/base_agent.py
def __del__(self) -> None:
    """Close wandb upon deletion."""
    self.env.close()  # type: ignore
    if self.log_wandb:
        wandb.finish()

adapt_hps #

adapt_hps(metrics: Dict) -> None

Set hyperparameters.

Source code in mighty/mighty_agents/dqn.py
def adapt_hps(self, metrics: Dict) -> None:
    """Set hyperparameters."""
    super().adapt_hps(metrics)
    if "hp/soft_update_weight" in metrics:
        self.soft_update_weight = metrics["hp/soft_update_weight"]
    for g in self.qlearning.optimizer.param_groups:  # type: ignore
        g["lr"] = self.learning_rate
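
A usage sketch, assuming an agent constructed as in the examples above. Hyperparameters arrive as "hp/"-prefixed entries in the metrics dictionary; this override additionally reads hp/soft_update_weight and writes the current learning rate into the optimizer's parameter groups. Which further keys the base class consumes is not shown here.

# Adjust the soft update weight (and refresh the optimizer's learning rate).
agent.adapt_hps({"hp/soft_update_weight": 0.005})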

apply_config #

apply_config(config: Dict) -> None

Apply config to agent.

Source code in mighty/mighty_agents/base_agent.py
def apply_config(self, config: Dict) -> None:
    """Apply config to agent."""
    for n in config:
        algo_name = n.split(".")[-1]
        if hasattr(self, algo_name):
            setattr(self, algo_name, config[n])
        elif hasattr(self, "_" + algo_name):
            setattr(self, "_" + algo_name, config[n])
        elif n in ["architecture", "n_units", "n_layers", "size"]:
            pass
        else:
            print(f"Trying to set hyperparameter {algo_name} which does not exist.")
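
A sketch of the config format this method expects: only the last dot-separated component of a key is used as the attribute name, with a fallback to the underscore-prefixed attribute; a handful of architecture keys are passed over, and unknown names are merely reported. The keys below are hypothetical.

config = {
    "algorithm.learning_rate": 1e-4,  # attribute name is the last component: learning_rate
    "algorithm.batch_size": 128,      # falls back to the private attribute _batch_size
}
agent.apply_config(config)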

evaluate #

evaluate(eval_env: MIGHTYENV | None = None) -> Dict

Eval agent on an environment. (Full rollouts).

:param env: The environment to evaluate on
:param episodes: The number of episodes to evaluate
:return:

Source code in mighty/mighty_agents/base_agent.py
def evaluate(self, eval_env: MIGHTYENV | None = None) -> Dict:  # type: ignore
    """Eval agent on an environment. (Full rollouts).

    :param env: The environment to evaluate on
    :param episodes: The number of episodes to evaluate
    :return:
    """

    terminated, truncated = False, False
    options: Dict = {}
    if eval_env is None:
        eval_env = self.eval_env

    state, _ = eval_env.reset(options=options)  # type: ignore
    rewards = np.zeros(eval_env.num_envs)  # type: ignore
    steps = np.zeros(eval_env.num_envs)  # type: ignore
    mask = np.zeros(eval_env.num_envs)  # type: ignore
    while not np.all(mask):
        action = self.policy(state, evaluate=True)  # type: ignore
        state, reward, terminated, truncated, _ = eval_env.step(action)  # type: ignore
        rewards += reward * (1 - mask)
        steps += 1 * (1 - mask)
        dones = np.logical_or(terminated, truncated)
        mask = np.where(dones, 1, mask)

    eval_env.close()  # type: ignore

    if isinstance(self.eval_env, DACENV) or isinstance(self.env, CARLENV):
        instance = eval_env.instance  # type: ignore
    else:
        instance = "None"

    eval_metrics = {
        "step": self.steps,
        "seed": self.seed,
        "eval_episodes": np.array(rewards) / steps,
        "mean_eval_step_reward": np.mean(rewards) / steps,
        "mean_eval_reward": np.mean(rewards),
        "instance": instance,
    }
    self.eval_buffer = update_buffer(self.eval_buffer, eval_metrics)

    # FIXME: this is the ugly I'm talking about
    if self.verbose:
        print("")
        print(
            "------------------------------------------------------------------------------"
        )
        print(
            f"""Evaluation performance after {self.steps} steps:
            {np.round(np.mean(rewards), decimals=2)}"""
        )
        print(
            f"""Evaluation performance per step after {self.steps} steps:
            {np.round(np.mean(rewards / steps), decimals=2)}"""
        )
        print(
            "------------------------------------------------------------------------------"
        )
        print("")

    if self.log_wandb:
        wandb.log(eval_metrics)

    return eval_metrics
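
A usage sketch, assuming the agent was constructed with an eval_env. The returned dictionary is also appended to the agent's eval_buffer and, if log_wandb is set, logged to wandb.

eval_metrics = agent.evaluate()
print(eval_metrics["mean_eval_reward"])       # mean return across the evaluation envs
print(eval_metrics["mean_eval_step_reward"])  # mean return normalized by episode length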

initialize_agent #

initialize_agent() -> None

General initialization of tracer and buffer for all agents.

Algorithm specific initialization like policies etc. are done in _initialize_agent

Source code in mighty/mighty_agents/base_agent.py
def initialize_agent(self) -> None:
    """General initialization of tracer and buffer for all agents.

    Algorithm specific initialization like policies etc.
    are done in _initialize_agent
    """
    self._initialize_agent()
    self.buffer = self.buffer_class(**self.buffer_kwargs)  # type: ignore

load #

load(path: str) -> None

Set the internal state of the agent, e.g. after loading.

Source code in mighty/mighty_agents/dqn.py
def load(self, path: str) -> None:
    """Set the internal state of the agent, e.g. after loading."""
    base_path = Path(path)
    q_path = base_path / "q.pt"
    q_state = torch.load(q_path)
    self.q.load_state_dict(q_state)  # type: ignore

    if self.q_target is not None:
        target_path = base_path / "q_target.pt"
        target_state = torch.load(target_path)
        self.q_target.load_state_dict(target_state)

    optimizer_path = base_path / "optimizer.pkl"
    optimizer_state_dict = torch.load(optimizer_path)["optimizer_state"]
    self.qlearning.optimizer.load_state_dict(optimizer_state_dict)  # type: ignore

    replay_path = base_path / "replay.pkl"
    if replay_path.exists():
        with replay_path.open("rb") as f:
            self.buffer = dill.load(f)
    if self.verbose:
        print(f"Loaded checkpoint at {path}")

make_checkpoint_dir #

make_checkpoint_dir(t: int) -> None

Checkpoint model.

:param T: Current timestep
:return:

Source code in mighty/mighty_agents/base_agent.py
def make_checkpoint_dir(self, t: int) -> None:
    """Checkpoint model.

    :param T: Current timestep
    :return:
    """
    self.upper_checkpoint_dir = Path(self.output_dir) / Path("checkpoints")
    if not self.upper_checkpoint_dir.exists():
        Path(self.upper_checkpoint_dir).mkdir()
    self.checkpoint_dir = self.upper_checkpoint_dir / f"{t}"
    if not self.checkpoint_dir.exists():
        Path(self.checkpoint_dir).mkdir()

run #

run(
    n_steps: int,
    eval_every_n_steps: int = 1000,
    human_log_every_n_steps: int = 5000,
    save_model_every_n_steps: int | None = 5000,
    env: MIGHTYENV = None,
) -> Dict

Run agent.

Source code in mighty/mighty_agents/base_agent.py
def run(  # noqa: PLR0915
    self,
    n_steps: int,
    eval_every_n_steps: int = 1_000,
    human_log_every_n_steps: int = 5000,
    save_model_every_n_steps: int | None = 5000,
    env: MIGHTYENV = None,  # type: ignore
) -> Dict:
    """Run agent."""
    episodes = 0
    if env is not None:
        self.env = env
    # FIXME: can we add the eval result here? Else the evals spam the command line in a pretty ugly way
    with Progress(
        "[progress.description]{task.description}",
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.0f}%",
        "Remaining:",
        TimeRemainingColumn(),
        "Elapsed:",
        TimeElapsedColumn(),
        disable=not self.render_progress,
    ) as progress:
        steps_task = progress.add_task(
            "Train Steps",
            total=n_steps - self.steps,
            start=False,
            visible=False,
        )
        steps_since_eval = 0
        progress.start_task(steps_task)
        # FIXME: this is more of a question: are there cases where we don't want to reset this completely?
        # I can't think of any, can you? If yes, we should maybe add this as an optional argument
        metrics = {
            "env": self.env,
            "vf": self.value_function,  # type: ignore
            "policy": self.policy,
            "step": self.steps,
            "hp/lr": self.learning_rate,
            "hp/pi_epsilon": self._epsilon,
            "hp/batch_size": self._batch_size,
            "hp/learning_starts": self._learning_starts,
        }

        # Reset env and initialize reward sum
        curr_s, _ = self.env.reset()  # type: ignore
        if len(curr_s.squeeze().shape) == 0:
            episode_reward = [0]
        else:
            episode_reward = np.zeros(curr_s.squeeze().shape[0])  # type: ignore

        last_episode_reward = episode_reward
        if not torch.is_tensor(last_episode_reward):
            last_episode_reward = torch.tensor(last_episode_reward).float()
        progress.update(steps_task, visible=True)

        # Main loop: rollouts, training and evaluation
        while self.steps < n_steps:
            metrics["episode_reward"] = episode_reward

            # TODO Remove
            progress.stop()

            action, log_prob = self.step(curr_s, metrics)

            next_s, reward, terminated, truncated, _ = self.env.step(action)  # type: ignore
            dones = np.logical_or(terminated, truncated)

            transition_metrics = self.process_transition(
                curr_s, action, reward, next_s, dones, log_prob, metrics
            )

            metrics.update(transition_metrics)

            episode_reward += reward

            # Log everything
            t = {
                "seed": self.seed,
                "step": self.steps,
                "reward": reward,
                "action": action,
                "state": curr_s,
                "next_state": next_s,
                "terminated": terminated.astype(int),
                "truncated": truncated.astype(int),
                "mean_episode_reward": last_episode_reward.mean(),
            }
            metrics["episode_reward"] = episode_reward
            self.result_buffer = update_buffer(self.result_buffer, t)

            if self.log_wandb:
                wandb.log(t)

            for k in self.meta_modules:
                self.meta_modules[k].post_step(metrics)

            self.steps += len(action)
            metrics["step"] = self.steps
            steps_since_eval += len(action)
            for _ in range(len(action)):
                progress.advance(steps_task)

            # Update agent
            if (
                len(self.buffer) >= self._batch_size  # type: ignore
                and self.steps >= self._learning_starts
            ):
                update_kwargs = {"next_s": next_s, "dones": dones}

                metrics = self.update(metrics, update_kwargs)

            # End step
            self.last_state = curr_s
            curr_s = next_s

            # Evaluate
            if eval_every_n_steps and steps_since_eval >= eval_every_n_steps:
                steps_since_eval = 0
                self.evaluate()

            # Log to command line
            if self.steps % human_log_every_n_steps == 0 and self.verbose:
                mean_last_ep_reward = np.round(
                    np.mean(last_episode_reward), decimals=2
                )
                mean_last_step_reward = np.round(
                    np.mean(mean_last_ep_reward / len(last_episode_reward)),
                    decimals=2,
                )
                print(
                    f"""Steps: {self.steps}, Latest Episode Reward: {mean_last_ep_reward}, Latest Step Reward: {mean_last_step_reward}"""  # noqa: E501
                )

            # Save
            if (
                save_model_every_n_steps
                and self.steps % save_model_every_n_steps == 0
            ):
                self.save(self.steps)
                log_to_file(
                    self.output_dir,
                    self.result_buffer,
                    self.hp_buffer,
                    self.eval_buffer,
                    self.loss_buffer,
                )

            if np.any(dones):
                last_episode_reward = np.where(  # type: ignore
                    dones, episode_reward, last_episode_reward
                )
                episode_reward = np.where(dones, 0, episode_reward)  # type: ignore
                # End episode
                if isinstance(self.env, DACENV) or isinstance(self.env, CARLENV):
                    instance = self.env.instance  # type: ignore
                else:
                    instance = None
                metrics["instance"] = instance
                episodes += 1
                for k in self.meta_modules:
                    self.meta_modules[k].post_episode(metrics)

                # Remove rollout data from last episode
                # TODO: only do this for finished envs
                # FIXME: open todo, I think we need to use dones as a mask here
                # Proposed fix: metrics[k][:, dones] = 0
                # I don't think this is correct masking and I think we have to check the size of zeros
                for k in list(metrics.keys()):
                    if "rollout" in k:
                        del metrics[k]

                # Meta Module hooks
                for k in self.meta_modules:
                    self.meta_modules[k].pre_episode(metrics)
    log_to_file(
        self.output_dir,
        self.result_buffer,
        self.hp_buffer,
        self.eval_buffer,
        self.loss_buffer,
    )
    return metrics

save #

save(t: int) -> None

Return current agent state, e.g. for saving.

For DQN, this consists of:

- the Q network parameters
- the Q network function state
- the target network parameters
- the target network function state

:return: Agent state

Source code in mighty/mighty_agents/dqn.py
def save(self, t: int) -> None:
    """Return current agent state, e.g. for saving.

    For DQN, this consists of:
    - the Q network parameters
    - the Q network function state
    - the target network parameters
    - the target network function state

    :return: Agent state
    """
    super().make_checkpoint_dir(t)
    # Save q parameters
    q_path = self.checkpoint_dir / "q.pt"
    torch.save(self.q.state_dict(), q_path)  # type: ignore

    # Save target parameters
    if self.q_target is not None:
        target_path = self.checkpoint_dir / "q_target.pt"
        torch.save(self.q_target.state_dict(), target_path)

    # Save optimizer state
    optimizer_path = self.checkpoint_dir / "optimizer.pkl"
    torch.save(
        {"optimizer_state": self.qlearning.optimizer.state_dict()},  # type: ignore
        optimizer_path,
    )

    # Save replay buffer
    if self.save_replay:
        replay_path = self.checkpoint_dir / "replay.pkl"
        self.buffer.save(replay_path)  # type: ignore

    if self.verbose:
        print(f"Saved checkpoint at {self.checkpoint_dir}")
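
A checkpoint round trip, assuming the output_dir from the first sketch. save(t) writes into output_dir/checkpoints/<t>/ (see make_checkpoint_dir above), and load expects exactly that directory.

agent.save(t=agent.steps)  # writes q.pt and optimizer.pkl, plus q_target.pt / replay.pkl when enabled
agent.load(f"runs/dqn_cartpole/checkpoints/{agent.steps}")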

update #

update(metrics: Dict, update_kwargs: Dict) -> Dict

Update agent.

Source code in mighty/mighty_agents/base_agent.py
def update(self, metrics: Dict, update_kwargs: Dict) -> Dict:
    """Update agent."""
    for k in self.meta_modules:
        self.meta_modules[k].pre_update(metrics)

    agent_update_metrics = self.update_agent(**update_kwargs)
    metrics.update(agent_update_metrics)
    metrics = {k: np.array(v) for k, v in metrics.items()}
    metrics["step"] = self.steps

    if self.log_wandb:
        wandb.log(metrics)

    metrics["env"] = self.env
    metrics["vf"] = self.value_function  # type: ignore
    metrics["policy"] = self.policy
    for k in self.meta_modules:
        self.meta_modules[k].post_update(metrics)
    return metrics

update_agent #

update_agent(**kwargs) -> Any

Compute and apply TD update.

:param step: Current training step
:return:

Source code in mighty/mighty_agents/dqn.py
def update_agent(self, **kwargs) -> Any:  # type: ignore
    """Compute and apply TD update.

    :param step: Current training step
    :return:
    """

    transition_batch = self.buffer.sample(batch_size=self._batch_size)  # type: ignore
    preds, targets = self.qlearning.get_targets(  # type: ignore
        transition_batch, self.q, self.q_target
    )

    metrics_q = self.qlearning.apply_update(preds, targets)  # type: ignore
    metrics_q["Update/td_targets"] = targets.detach().numpy()
    metrics_q["Update/td_errors"] = (targets - preds).detach().numpy()
    loss_stats = {
        "step": self.steps,
        "Update/loss": metrics_q["Update/loss"],
        "Update/td_errors": metrics_q["Update/td_errors"].mean().item(),
        "batch_predictions": preds.mean(axis=1).detach().numpy().tolist(),
    }
    self.loss_buffer = update_buffer(self.loss_buffer, loss_stats)

    # sync target model
    if self.q_target is not None:
        for param, target_param in zip(
            self.q.parameters(),  # type: ignore
            self.q_target.parameters(),
            strict=False,
        ):
            target_param.data.copy_(
                self.soft_update_weight * param.data
                + (1 - self.soft_update_weight) * target_param.data
            )

    return metrics_q
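
For reference, the targets computed by the QLearning / DoubleQLearning updates and the soft target synchronization in the loop above correspond to the standard formulation below (a textbook summary, not taken verbatim from the source), with \tau equal to soft_update_weight:

y_t^{\mathrm{DQN}}  = r_t + \gamma \, (1 - d_t) \, \max_{a'} Q_{\theta^-}(s_{t+1}, a')
y_t^{\mathrm{DDQN}} = r_t + \gamma \, (1 - d_t) \, Q_{\theta^-}\big(s_{t+1}, \arg\max_{a'} Q_{\theta}(s_{t+1}, a')\big)
\theta^- \leftarrow \tau \, \theta + (1 - \tau) \, \theta^-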