"""Planning environment from
"Learning Heuristic Selection with Dynamic Algorithm Configuration"
by David Speck, André Biedenkapp, Frank Hutter, Robert Mattmüller and Marius Lindauer.
Original environment authors: David Speck, André Biedenkapp.
"""
from __future__ import annotations
import os
import socket
import subprocess
import time
from copy import deepcopy
from enum import Enum
from pathlib import Path
import numpy as np
from dacbench import AbstractEnv
class StateType(Enum):
    """Class to define numbers for state types.

    The integer values are what ``config.state_type`` holds;
    ``FastDownwardEnv`` converts it with ``StateType(config.state_type)``.
    In the descriptions below "reference" is the first observation of the
    episode, and divisions by a zero reference yield 0.
    """

    RAW = 1  # features are passed through untransformed
    DIFF = 2  # current - previous
    ABSDIFF = 3  # |current - previous|
    NORMAL = 4  # current / reference
    NORMDIFF = 5  # current / reference - previous / reference
    NORMABSDIFF = 6  # |current / reference - previous / reference|
class FastDownwardEnv(AbstractEnv):
    """Environment to control Solver Heuristics of FastDownward.

    Every ``reset`` starts a Fast Downward (FD) subprocess and waits for it to
    connect to a TCP socket owned by the environment. Messages in both
    directions are prefixed with their length as four ASCII digits
    (see ``send_msg`` / ``recv_msg``).
    """

    # General state features that bypass the state transformation.
    # NOTE(review): reconstructed from the formerly hard-coded skip indices
    # 4, 7, 8, 9 (index 4 was documented as "num_variables"; 7-9 presumably the
    # three ``cg_num_*`` features). Confirm before enabling general features.
    _UNTRANSFORMED_GENERAL_FEATURES = frozenset(
        {"num_variables", "cg_num_eff_to_eff", "cg_num_eff_to_pre", "cg_num_pre_to_eff"}
    )

    def __init__(self, config):
        """Initialize FD Env.

        Parameters
        ----------
        config : objdict
            Environment configuration
        """
        super().__init__(config)
        self._heuristic_state_features = [
            "Average Value",  # 'Dead Ends Reliable',
            "Max Value",
            "Min Value",
            "Open List Entries",
            "Varianz",
        ]
        self._general_state_features = [
            # 'evaluated_states', 'evaluations', 'expanded_states',
            # 'generated_ops',
            # 'generated_states', 'num_variables',
            # 'registered_states', 'reopened_states',
            # "cg_num_eff_to_eff", "cg_num_eff_to_pre", "cg_num_pre_to_eff"
        ]
        self._use_gsi = config.use_general_state_info
        total_state_features = len(config.heuristics) * len(
            self._heuristic_state_features
        )
        if self._use_gsi:
            total_state_features += len(self._general_state_features)
        # One flag per state entry; True means the entry is not transformed.
        # General features come first in the state vector (see _process_data),
        # so a feature's position in _general_state_features is its state index.
        # Deriving the flags by name keeps them valid when the list changes.
        self.__skip_transform = [False] * total_state_features
        if self._use_gsi:
            for idx, feature in enumerate(self._general_state_features):
                if feature in self._UNTRANSFORMED_GENERAL_FEATURES:
                    self.__skip_transform[idx] = True
        self.heuristics = config.heuristics
        self.host = config.host
        self._port = config.get("port", 0)
        if config.get("parallel", False):
            # port 0: the OS picks a free port when the socket is bound
            self.port = 0
        self.fd_seed = config.fd_seed
        self.control_interval = config.control_interval
        if config.fd_logs is None:
            self.logpath_out = os.devnull
            self.logpath_err = os.devnull
        else:
            self.logpath_out = Path(config.fd_logs) / "fdout.txt"
            self.logpath_err = Path(config.fd_logs) / "fderr.txt"
        self.fd_path = config.fd_path
        self.fd = None  # FD subprocess (subprocess.Popen) of the current episode
        if "domain_file" in config:
            self.domain_file = config["domain_file"]
        self.socket = None  # listening server socket
        self.conn = None  # connection accepted from FD
        self._prev_state = None
        self.num_steps = config.num_steps
        self.__state_type = StateType(config.state_type)
        self.__norm_vals = []
        self._config_dir = config.config_dir
        self._port_file_id = config.port_file_id
        # f(current, previous, normalization value, skip) -> transformed value;
        # None for StateType.RAW
        self._transformation_func = self._make_transformation_func(
            self.__state_type
        )
        self.max_rand_steps = config.max_rand_steps
        self.__start_time = None
        self.done = True  # Starts as true as the expected behavior is that
        # before normal resets an episode was done.

    @staticmethod
    def _make_transformation_func(state_type):
        """Return the per-feature transformation for ``state_type``.

        The returned callable takes (current, previous, normalization value,
        skip) and returns ``current`` unchanged when ``skip`` is true.
        Returns None for ``StateType.RAW``.
        """
        div = FastDownwardEnv._save_div
        if state_type == StateType.DIFF:
            return lambda x, y, z, skip: x if skip else x - y
        if state_type == StateType.ABSDIFF:
            return lambda x, y, z, skip: x if skip else abs(x - y)
        if state_type == StateType.NORMAL:
            return lambda x, y, z, skip: x if skip else div(x, z)
        if state_type == StateType.NORMDIFF:
            return lambda x, y, z, skip: x if skip else div(x, z) - div(y, z)
        if state_type == StateType.NORMABSDIFF:
            return lambda x, y, z, skip: x if skip else abs(div(x, z) - div(y, z))
        return None

    @property
    def port(self):
        """Port the environment listens on.

        Raises
        ------
        ValueError
            If automatic port selection (port 0) is active and no socket has
            been created yet.
        """
        if self._port == 0:
            if self.socket is None:
                raise ValueError(
                    "Automatic port selection enabled. Port not know at the moment"
                )
            _, port = self.socket.getsockname()
        else:
            port = self._port
        return port

    @port.setter
    def port(self, port):
        self._port = port

    @property
    def _argstring(self):
        """FD ``--search`` argument configuring the RL-controlled eager search."""
        heuristics = ",".join(str(h) for h in self.heuristics)
        return (
            f"rl_eager(rl([{heuristics}],"
            f"random_seed={self.fd_seed}),rl_control_interval={self.control_interval},rl_client_port={self.port})"
        )

    @staticmethod
    def _save_div(a, b):
        """Helper method for safe division.

        Division happens in floating point (integer inputs would otherwise
        make numpy refuse to store the float quotient). Entries where ``b`` is
        zero yield 0.

        Parameters
        ----------
        a : scalar, list or np.array
            values to be divided
        b : scalar, list or np.array
            values to divide by (broadcast against ``a``)

        Returns
        -------
        np.array
            Division result (0-d for scalar inputs)
        """
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        out = np.zeros(np.broadcast(a, b).shape)
        return np.divide(a, b, out=out, where=b != 0)

    def send_msg(self, msg: bytes):
        """Send message and prepend the message size.

        Based on comment from SO see [1]
        [1] https://stackoverflow.com/a/17668009

        Parameters
        ----------
        msg : bytes
            The message as byte

        Raises
        ------
        ValueError
            If the message is too long for the 4-digit length prefix.
        """
        # Prefix each message with its length as exactly four ASCII digits;
        # the receiver always reads a 4-byte header.
        header = f"{len(msg):04d}"
        if len(header) != 4:
            raise ValueError(
                f"Message of {len(msg)} bytes exceeds the 9999-byte protocol limit"
            )
        self.conn.sendall(str.encode(header) + msg)

    def recv_msg(self):
        """Receive a whole message. The message has to be prepended with its size.

        Based on comment from SO see [1].

        Returns
        -------
        bytes or None
            The message as byte, or None if the connection was closed
            before a complete message arrived.
        """
        # Read the 4-digit ASCII length prefix
        raw_msglen = self.recvall(4)
        if not raw_msglen:
            return None
        msglen = int(raw_msglen.decode())
        # Read the message data
        return self.recvall(msglen)

    def recvall(self, n: int):
        """Receive exactly ``n`` bytes from the connection.

        Based on comment from SO see [1].

        Parameters
        ----------
        n: int
            Number of bytes to expect in the data

        Returns
        -------
        bytes or None
            The message as byte, or None if EOF is hit first.
        """
        buf = bytearray()
        while len(buf) < n:
            packet = self.conn.recv(n - len(buf))
            if not packet:
                return None
            buf.extend(packet)
        return bytes(buf)

    def _process_data(self):
        """Receive one message from FD and split it into state, reward and done.

        Returns
        -------
        np.array, float, bool
            state, reward, done

        Raises
        ------
        ConnectionError
            If FD closed the connection before a complete message arrived.
        """
        payload = self.recv_msg()
        if payload is None:
            raise ConnectionError(
                "Connection to Fast Downward closed before a message was received"
            )
        msg = payload.decode()
        # Infinite values are mapped to 0; bare `inf` would not evaluate.
        msg = msg.replace("-inf", "0")
        msg = msg.replace("inf", "0")
        # NOTE(review): eval() executes whatever the FD subprocess sends. It is
        # a locally spawned process, but ast.literal_eval would be safer if
        # the messages are guaranteed to be plain literals.
        data = eval(msg)  # noqa: S307
        r = data.pop("reward")
        done = data.pop("done")
        state = []
        if self._use_gsi:
            state.extend(data[feature] for feature in self._general_state_features)
        for heuristic_id in range(len(self.heuristics)):
            heuristic_data = data[str(heuristic_id)]
            state.extend(
                heuristic_data[feature] for feature in self._heuristic_state_features
            )
        if self._prev_state is None:
            # First observation of the episode: reference for normalization
            # and the "previous" state of the first transformation.
            self.__norm_vals = deepcopy(state)
            self._prev_state = deepcopy(state)
        if self.__state_type != StateType.RAW:
            # Transform state to DIFF state or normalize
            raw_state = state
            state = list(
                map(
                    self._transformation_func,
                    state,
                    self._prev_state,
                    self.__norm_vals,
                    self.__skip_transform,
                )
            )
            self._prev_state = raw_state
        return np.array(state), r, done

    def step(self, action: int | list[int]):
        """Environment step.

        Parameters
        ----------
        action: typing.Union[int, List[int]]
            Parameter(s) to apply; for a sequence only the first entry is used.

        Returns
        -------
        np.array, float, bool, bool, dict
            state, reward, terminated, truncated, info
        """
        self.done = super().step_()
        # np.issubdtype covers both built-in ints and numpy integer scalars
        if not np.issubdtype(type(action), np.integer):
            try:
                action = action[0]
            except IndexError as e:
                raise IndexError(
                    f"Cannot extract an action from {type(action)}: {action!r}"
                ) from e
        if self.num_steps:
            msg = ",".join([str(action), str(self.num_steps)])
        else:
            msg = str(action)
        self.send_msg(str.encode(msg))
        s, r, terminated = self._process_data()
        r = max(self.reward_range[0], min(self.reward_range[1], r))
        info = {}
        if terminated:
            self.done = True
            self.kill_connection()
        if self.c_step > self.n_steps:
            info["needs_reset"] = True
            # The connection is already gone if FD terminated in this step.
            if self.conn is not None:
                self.send_msg(str.encode("END"))
            self.kill_connection()
        return s, r, terminated, self.done, info

    def reset(self, seed=None, options=None):
        """Reset environment.

        Closes the connection to the previous FD run, starts a new FD
        subprocess on ``self.instance`` and waits for it to connect.

        Parameters
        ----------
        seed : int, optional
            Seed forwarded to ``reset_`` of the base class.
        options : dict, optional
            Unused.

        Returns
        -------
        np.array
            State after reset
        dict
            Meta-info

        Raises
        ------
        OSError
            If FD does not connect before the socket times out.
        """
        if options is None:
            options = {}
        super().reset_(seed)
        self._prev_state = None
        self.__start_time = time.time()
        if not self.done and self.conn is not None:
            # We interrupt FD before a plan was found:
            # inform FD about imminent shutdown of the connection
            self.send_msg(str.encode("END"))
        self.done = False
        if self.conn:
            self._shutdown_and_close(self.conn)
            self.conn = None
        if not self.socket:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.settimeout(60)
            # a socket bound to port 0 automatically chooses a free port
            self.socket.bind((self.host, self._port))
        # Listen before FD is started so its connection attempt cannot arrive
        # while the socket is not yet accepting connections.
        self.socket.listen()
        self._terminate_fd()
        command = self._build_command()
        # Write the port down (so FD can read where to connect to) before FD
        # starts, so it can never pick up a stale file.
        fp = self._port_file_path()
        with open(fp, "w") as portfh:
            portfh.write(str(self.port))
        try:
            with open(self.logpath_out, "a+") as fout, open(
                self.logpath_err, "a+"
            ) as ferr:
                err_output = (
                    subprocess.STDOUT if self.logpath_err == os.devnull else ferr
                )
                self.fd = subprocess.Popen(command, stdout=fout, stderr=err_output)
            self.conn, _ = self.socket.accept()
        except (TimeoutError, socket.timeout) as err:
            raise OSError(
                "Fast downward subprocess not reachable (time out). "
                "Possible solutions:\n"
                " (1) Did you run './dacbench/envs/rl-plan/fast-downward/build.py' "
                "in order to build the fd backend?\n"
                " (2) Try to fix this by setting OPENBLAS_NUM_THREADS=1. "
                "For more details see https://github.com/automl/DACBench/issues/96"
            ) from err
        finally:
            # remove the port file such that there is no chance of loading the old port
            self._remove_file(fp)
        s, _, _ = self._process_data()
        if self.max_rand_steps > 1:
            for _ in range(self._draw_num_rand_steps()):
                s, _, _, _, _ = self.step(self.action_space.sample())
                if self.conn is None:
                    # step() closed the connection (episode ended): start over
                    return self.reset()
        else:
            s, _, _, _, _ = self.step(0)  # hard coded to zero as initial step
        return s, {}

    def _build_command(self):
        """Return the command line that launches FD on the current instance.

        Raises
        ------
        ValueError
            If the instance is a PDDL file but no ``domain_file`` is configured.
        """
        command = ["python3", f"{self.fd_path}"]
        if self.instance.parts[-1].endswith(".pddl"):
            domain_file = getattr(self, "domain_file", None)
            if domain_file is None:
                raise ValueError(
                    "PDDL instances require 'domain_file' in the environment config"
                )
            command.append(domain_file)
        command += [self.instance, "--search", self._argstring]
        return command

    def _port_file_path(self):
        """Path of the file the listening port is written to for FD."""
        if self._port_file_id:
            return Path(self._config_dir) / f"port_{self._port_file_id:d}.txt"
        return Path(self._config_dir) / f"port_{self.port}.txt"

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` if it exists."""
        try:
            Path(path).unlink()
        except FileNotFoundError:
            pass

    def _draw_num_rand_steps(self):
        """Draw the number of random warm-up steps uniformly from [1, max_rand_steps]."""
        rng = self.np_random
        # np.random.Generator (gymnasium seeding) has `integers`, the legacy
        # RandomState has `randint`; both exclude the upper bound.
        draw = rng.integers if hasattr(rng, "integers") else rng.randint
        return int(draw(1, self.max_rand_steps + 1))

    def _terminate_fd(self, timeout: float = 5.0):
        """Terminate the FD subprocess (if still running) and reap it.

        Falls back to ``kill`` if FD does not exit within ``timeout`` seconds.
        """
        if self.fd is None:
            return
        if self.fd.poll() is None:
            self.fd.terminate()
            try:
                self.fd.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                self.fd.kill()
                self.fd.wait()

    @staticmethod
    def _shutdown_and_close(sock):
        """Shut down and close ``sock``; closing happens even if shutdown fails."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Deliberate best effort: shutdown fails on some platforms for
            # listening sockets or when the peer is already gone.
            pass
        sock.close()

    def kill_connection(self):
        """Kill the connection and the listening socket."""
        if self.conn:
            self._shutdown_and_close(self.conn)
            self.conn = None
        if self.socket:
            self._shutdown_and_close(self.socket)
            self.socket = None

    def close(self):
        """Close Env.

        Removes the port file, closes all sockets and stops the FD subprocess.

        Returns
        -------
        bool
            Closing confirmation
        """
        if self.socket is not None:
            self._remove_file(self._port_file_path())
            self.kill_connection()
        self._terminate_fd()
        self.fd = None
        return True

    def render(self, mode: str = "human") -> None:
        """Required by gym.Env but not implemented.

        Parameters
        ----------
        mode : str
            Rendering mode
        """