# Source code for dacbench.wrappers.multidiscrete_action_wrapper

"""Wrapper for casting MultiDiscrete action spaces to Discrete."""

from __future__ import annotations

import itertools

import numpy as np
from gymnasium import Wrapper, spaces


class MultiDiscreteActionWrapper(Wrapper):
    """Wrapper to cast MultiDiscrete action spaces to Discrete.

    Every combination of sub-actions of the wrapped ``MultiDiscrete`` space is
    enumerated once, in ``itertools.product`` order (the last dimension varies
    fastest), and assigned a single integer index. This should improve
    usability with standard RL libraries that only support ``Discrete``
    action spaces.
    """

    def __init__(self, env):
        """Initialize wrapper.

        Parameters
        ----------
        env : gym.Env
            Environment to wrap. Its ``action_space`` is expected to be a
            one-dimensional ``gymnasium.spaces.MultiDiscrete``.
        """
        super().__init__(env)
        # Work on the underlying numpy arrays: iterating the space object
        # itself yields ``Discrete`` sub-spaces, which neither ``np.prod`` nor
        # ``np.arange`` can consume.
        nvec = np.asarray(self.env.action_space.nvec)
        # ``start`` offsets only exist on newer gymnasium versions; treat a
        # missing attribute as all-zero offsets (the classic behaviour).
        start = np.asarray(
            getattr(self.env.action_space, "start", np.zeros_like(nvec))
        )
        self.n_actions = len(nvec)
        self.action_space = spaces.Discrete(int(np.prod(nvec)))
        # Discrete index -> tuple of sub-actions, one entry per combination.
        self.action_mapper = dict(
            enumerate(
                itertools.product(
                    *[
                        np.arange(lo, lo + n)
                        for lo, n in zip(start, nvec, strict=True)
                    ]
                )
            )
        )

    def step(self, action):
        """Map a discrete action index to its multi-discrete action and step.

        Parameters
        ----------
        action : int
            Index into ``action_mapper``. Numpy scalars and size-1 arrays
            (as emitted by some RL libraries) are unwrapped first.

        Returns
        -------
        Whatever the wrapped environment's ``step`` returns.

        Raises
        ------
        KeyError
            If ``action`` is not a valid index of ``self.action_space``.
        """
        # ``.item()`` turns numpy scalars / 0-d arrays (unhashable) into plain
        # Python scalars without changing the value, so the dict lookup keeps
        # its original semantics.
        key = np.asarray(action).item()
        return self.env.step(self.action_mapper[key])