Skip to content

Mighty runner

mighty.mighty_runners.mighty_runner

MightyRunner

MightyRunner(cfg: DictConfig)

Bases: ABC

Source code in mighty/mighty_runners/mighty_runner.py
def __init__(self, cfg: DictConfig) -> None:
    """Parse config and set up the Mighty agent.

    Creates the run's output directory and builds the train and eval
    environments. It applies the configured wrappers to both environments,
    instantiates the agent and optionally restores it from a checkpoint.
    Training itself is not started here.

    Args:
        cfg: Experiment configuration. Fields read directly here:
            ``output_dir``, ``experiment_name``, ``seed``, ``env_wrappers``,
            optional ``wrapper_kwargs``, ``algorithm``, ``algorithm_kwargs``,
            ``eval_every_n_steps``, ``num_steps`` and ``checkpoint``. It is
            also passed whole to ``make_mighty_env``.
    """
    # Each run writes to <output_dir>/<experiment_name>_<seed>.
    output_dir = Path(cfg.output_dir) / f"{cfg.experiment_name}_{cfg.seed}"
    # exist_ok=True replaces an exists()/mkdir() pair, which could race when
    # several runs create the same directory concurrently.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Check whether env is from DACBench, CARL or gym
    # Make train and eval env; the third return value is not used here.
    env, base_eval_env, _eval_default = make_mighty_env(cfg)

    # TODO: move wrapping to env handling?
    # The same kwargs are passed to every wrapper, so look them up only once.
    wrapper_kwargs = cfg.wrapper_kwargs if "wrapper_kwargs" in cfg else {}
    wrapper_classes = [get_class(w) for w in cfg.env_wrappers]

    def apply_wrappers(base_env):  # type: ignore
        """Wrap ``base_env`` with every configured wrapper, in config order."""
        for wrapper_cls in wrapper_classes:
            base_env = wrapper_cls(base_env, **wrapper_kwargs)
        return base_env

    # Train and eval envs receive an identical wrapper stack. base_eval_env
    # is a factory, so a fresh eval env is built before wrapping.
    env = apply_wrappers(env)
    eval_env = apply_wrappers(base_eval_env())

    # Setup agent
    # TODO: agent currently needs more than just algo and algo_kwargs (see logging)
    agent_class = get_agent_class(cfg.algorithm)
    args_agent = dict(cfg.algorithm_kwargs)
    self.agent = agent_class(  # type: ignore
        env=env,
        eval_env=eval_env,
        output_dir=output_dir,
        seed=cfg.seed,
        **args_agent,
    )

    self.eval_every_n_steps = cfg.eval_every_n_steps
    self.num_steps = cfg.num_steps

    # Load checkpoint if one is given. Log first, so that a failing load
    # is attributed to the checkpoint path in the log.
    if cfg.checkpoint is not None:
        logging.info("#" * 80)
        logging.info("Loading checkpoint at %s", cfg.checkpoint)
        self.agent.load(cfg.checkpoint)

    # Announce the agent that will be trained.
    logging.info("#" * 80)
    logging.info('Using agent type "%s" to learn', self.agent)
    logging.info("#" * 80)