"""mdp_playground.spaces.image_continuous: image observation spaces for continuous feature spaces."""

import warnings
import numpy as np
import gym
from gym.spaces import Box, Space
import PIL.ImageDraw as ImageDraw
import PIL.Image as Image
from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM
import os

class ImageContinuous(Box):
    '''A space that maps a continuous 1- or 2-D space 1-to-1 to images so that
    the images may be used as representations for corresponding continuous
    environments.

    The relevant dimensions of the feature space are rendered as one image
    and, if there are any, the irrelevant dimensions as a second image. The
    images are concatenated along the X-axis.

    Methods
    -------
    get_concatenated_image(continuous_obs)
        Gets an image representation for a given feature space observation
    '''

    def __init__(self, feature_space, term_spaces=None, width=100, height=100,
                 circle_radius=5, target_point=None, relevant_indices=None,
                 seed=None, grid_shape=None, dtype=np.uint8):
        '''
        Parameters
        ----------
        feature_space : Gym.spaces.Box
            The 1-D feature space to which this class associates images as
            external observations. The bounds of the relevant dimensions must
            be finite because they are used to map positions to pixels.
        term_spaces : list of Gym.spaces.Box
            Sub-spaces of the feature space which are terminal. They are drawn
            as rectangles in the relevant image.
        width : int
            The width of the image
        height : int
            The height of the image
        circle_radius : int
            The radius of the circle which represents the agent and target point
        target_point : np.array
            Position of the target, drawn as a circle in the relevant image.
            The passed array is not modified.
        relevant_indices : list
            Indices of the feature space rendered in the relevant image (at
            most 2). Defaults to [0, 1]. The remaining indices (at most 2) are
            rendered in a separate "irrelevant" image.
        seed : int
            Seed for this space
        grid_shape : tuple of length 2 or 4
            If given, grid lines are drawn and positions are offset by 0.5
            (presumably to land at grid-cell centres). Entries 0 and 1
            describe the grid of the relevant image. Entries 2 and 3, if
            present, describe the grid of the irrelevant image; otherwise the
            first two entries are reused for it.
        dtype : numpy dtype
            dtype of the underlying Box
        '''
        # ##TODO Define a common superclass for this and ImageMultiDiscrete
        if relevant_indices is None:
            # Fresh list per instance instead of a shared mutable default.
            relevant_indices = [0, 1]

        self.feature_space = feature_space
        assert len(feature_space.shape) == 1
        self.width = width
        self.height = height
        # Warn if resolution is too low?
        self.circle_radius = circle_radius
        self.target_point = target_point
        self.term_spaces = term_spaces
        self.relevant_indices = relevant_indices

        # convert_to_pixel() normalises with the bounds of the relevant
        # dimensions, so every one of them must be finite (not just one).
        assert np.all(np.isfinite(self.feature_space.high[self.relevant_indices]))
        assert np.all(np.isfinite(self.feature_space.low[self.relevant_indices]))

        all_indices = set(range(self.feature_space.shape[0]))
        # sorted() so the irrelevant image always uses a deterministic order.
        self.irrelevant_indices = sorted(all_indices - set(self.relevant_indices))
        self.irrelevant_features = len(self.irrelevant_indices) > 0

        if grid_shape is not None:
            self.draw_grid = True
            assert isinstance(grid_shape, (tuple, list)) and \
                len(grid_shape) in (2, 4)
            # Could also assert that self.width is divisible by grid_shape[0], etc.
            self.grid_shape = grid_shape
        else:
            self.draw_grid = False

        self.goal_colour = (0, 255, 0)
        self.agent_colour = (0, 0, 255)
        self.term_colour = (255, 0, 0)
        self.bg_colour = (0, 0, 0)
        self.line_colour = (255, 255, 255)

        relevant_dims = len(self.relevant_indices)
        irr_dims = len(self.irrelevant_indices)
        assert relevant_dims <= 2 and irr_dims <= 2, "Image observations are "\
            "supported only "\
            "for 1- or 2-D feature spaces."

        # Shape has 1 appended for Ray Rllib to be compatible IIRC
        super(ImageContinuous, self).__init__(shape=(width, height, 1),
                                              dtype=dtype, low=0, high=255)
        super(ImageContinuous, self).seed(seed=seed)

        if self.target_point is not None:
            # Work on a float copy: an in-place "+= 0.5" would modify the
            # caller's array (and self.target_point), and would fail for
            # integer arrays.
            target_pos = np.asarray(self.target_point, dtype=float)
            if self.draw_grid:
                target_pos = target_pos + 0.5
            self.target_point_pixel = self.convert_to_pixel(target_pos)

    def _draw_grid_lines(self, draw, relevant):
        '''Draw the vertical and horizontal grid lines onto ``draw``.

        Parameters
        ----------
        draw : PIL.ImageDraw.ImageDraw
            Drawing context of the image being generated.
        relevant : bool
            Whether the relevant or the irrelevant image is being drawn.
        '''
        # grid_shape[2:4] describes the irrelevant image's grid. If only two
        # entries were given, reuse the relevant grid (previously IndexError).
        offset = 2 if (not relevant and len(self.grid_shape) == 4) else 0
        n_cols = self.grid_shape[0 + offset]
        n_rows = self.grid_shape[1 + offset]

        # Vertical lines. +1 because X is the concatenation dimension when
        # stitching images together in get_concatenated_image, so a line is
        # also drawn at the right edge of this image.
        for i in range(1, n_cols + 1):
            # -1 to not go outside image size for the last line drawn
            x_ = i * self.width // n_cols - 1
            draw.line([(x_, self.height), (x_, 0)], fill=self.line_colour)

        # Horizontal lines: spacing is based on the number of rows.
        for j in range(1, n_rows):
            y_ = j * self.height // n_rows
            draw.line([(self.width, y_), (0, y_)], fill=self.line_colour)

    def generate_image(self, position, relevant=True):
        '''Render one sub-space observation as an RGB image array.

        Parameters
        ----------
        position : np.array
            Position in the (relevant or irrelevant) sub-space, drawn as the
            agent circle.
        relevant : bool
            If True, also draw terminal spaces and the target point. If False,
            draw only the grid (if any) and the agent.

        Returns
        -------
        np.ndarray
            Array of shape (width, height, 3): PIL's (height, width, 3) array
            transposed so that the first axis is X.
        '''
        # Use RGB
        image_ = Image.new("RGB", (self.width, self.height),
                           color=self.bg_colour)
        # Use L for black and white 8-bit pixels instead of RGB in case not
        # using custom images
        draw = ImageDraw.Draw(image_)

        # Draw in increasing order of importance: grid lines, term_spaces,
        # etc. first, so that the more important items are drawn over them.
        if self.draw_grid:
            position = position.astype(float) + 0.5
            self._draw_grid_lines(draw, relevant)

        if self.term_spaces is not None and relevant:
            for term_space in self.term_spaces:
                low = self.convert_to_pixel(term_space.low)
                # With a grid, a terminal cell extends up to high + 1.
                high_pos = (term_space.high + 1.) if self.draw_grid \
                    else term_space.high
                high = self.convert_to_pixel(high_pos)
                draw.rectangle([tuple(low), tuple(high)],
                               fill=self.term_colour)

        R = self.circle_radius
        if self.target_point is not None and relevant:
            draw.ellipse([tuple(self.target_point_pixel - R),
                          tuple(self.target_point_pixel + R)],
                         fill=self.goal_colour)

        # Draw circle https://stackoverflow.com/a/2980931/11063709
        pos_pixel = self.convert_to_pixel(position)
        draw.ellipse([tuple(pos_pixel - R), tuple(pos_pixel + R)],
                     fill=self.agent_colour)

        # Because numpy is row-major and Image is column major, need to transpose
        return np.transpose(np.array(image_), axes=(1, 0, 2))

    def get_concatenated_image(self, obs):
        '''Gets the "stitched together" image made from images corresponding
        to each continuous sub-space within the continuous space, concatenated
        along the X-axis.

        Parameters
        ----------
        obs : np.array
            An observation from the feature space.

        Returns
        -------
        np.ndarray
            Shape (width * k, height, 3), where k is 2 if there are irrelevant
            dimensions and 1 otherwise.
        '''
        images = [self.generate_image(obs[self.relevant_indices])]
        if self.irrelevant_features:
            images.append(self.generate_image(obs[self.irrelevant_indices],
                                              relevant=False))
        # np.atleast_3d because Ray expects an image to have >=3 dims
        return np.atleast_3d(np.concatenate(images, axis=0))

    def convert_to_pixel(self, position):
        '''Map a feature-space position to integer pixel coordinates.

        The position is normalised with the bounds of the relevant dimensions
        and scaled by (width, height), then truncated to int.
        '''
        # It's implicit that both relevant and irrelevant sub-spaces have the
        # same max and min here:
        high = self.feature_space.high[self.relevant_indices]
        low = self.feature_space.low[self.relevant_indices]
        normalised = (position - low) / (high - low)
        return (normalised * self.shape[:2]).astype(int)

    def sample(self):
        '''Sample a feature-space point and return its image representation.'''
        sampled = self.feature_space.sample()
        return self.get_concatenated_image(sampled)

    def __repr__(self):
        return "{} with continuous underlying space of shape: {} and "\
            "images of resolution: {} and dtype: {}".format(
                self.__class__, self.feature_space.shape,
                self.shape, self.dtype)

    def contains(self, x):
        """Return boolean specifying if x is a valid member of this space.

        Only the shape is checked. Two shapes are accepted: the declared Box
        shape (width, height, 1), and the shape of images actually produced by
        get_concatenated_image / sample, (width * k, height, 3).
        """
        # TODO compare each pixel for all possible images?
        n_images = 2 if self.irrelevant_features else 1
        rendered_shape = (self.width * n_images, self.height, 3)
        return np.shape(x) in ((self.width, self.height, 1), rendered_shape)

    def to_jsonable(self, sample_n):
        """Convert a batch of samples from this space to a JSONable data type."""
        # Not supported for image spaces.
        raise NotImplementedError

    def from_jsonable(self, sample_n):
        """Convert a JSONable data type to a batch of samples from this space."""
        # Not supported for image spaces.
        raise NotImplementedError

    def __eq__(self, other):
        # Equality between image spaces is deliberately not implemented.
        raise NotImplementedError