import numpy as np
from gym.spaces import Box
import PIL.ImageDraw as ImageDraw
import PIL.Image as Image
class ImageContinuous(Box):
    '''A space that maps a continuous 1- or 2-D feature space 1-to-1 to images, so
    that the images may be used as observation representations for the corresponding
    continuous environments.

    Methods
    -------
    get_concatenated_image(continuous_obs)
        Gets an image representation for a given feature space observation
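
    Examples
    --------
    An illustrative usage sketch (the feature-space bounds below are arbitrary
    example values)::

        feature_space = gym.spaces.Box(low=0.0, high=10.0, shape=(2,))
        image_space = ImageContinuous(feature_space)
        image_obs = image_space.get_concatenated_image(feature_space.sample())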
    '''
    def __init__(self, feature_space, term_spaces=None, width=100, height=100,
                 circle_radius=5, target_point=None, relevant_indices=[0, 1],
                 seed=None, grid_shape=None, dtype=np.uint8):
        '''
        Parameters
        ----------
        feature_space : Gym.spaces.Box
            The feature space to which this class associates images as external
            observations
        term_spaces : list of Gym.spaces.Box
            Sub-spaces of the feature space which are terminal
        width : int
            The width of the image
        height : int
            The height of the image
        circle_radius : int
            The radius of the circle which represents the agent and target point
        target_point : np.array
            The target (goal) point in the feature space; it is drawn as a circle
            in the goal colour
        relevant_indices : list
            Indices of the feature space dimensions that are relevant; any remaining
            dimensions are rendered as a separate image that is concatenated to the
            relevant one
        grid_shape : tuple of length 2 (or 4 if there are irrelevant dimensions)
            If provided, a grid of this shape is drawn over the image(s)
        seed : int
            Seed for this space
        '''
        # ##TODO Define a common superclass for this and ImageMultiDiscrete
        self.feature_space = feature_space
        assert (self.feature_space.high != np.inf).all(), \
            "feature_space must be bounded above to be mapped to an image"
        assert (self.feature_space.low != -np.inf).all(), \
            "feature_space must be bounded below to be mapped to an image"
        self.width = width
        self.height = height
        # Warn if resolution is too low?
        self.circle_radius = circle_radius
        self.target_point = target_point
        self.term_spaces = term_spaces
        self.relevant_indices = relevant_indices
        all_indices = set(range(self.feature_space.shape[0]))
        self.irrelevant_indices = list(all_indices - set(self.relevant_indices))
        if len(self.irrelevant_indices) == 0:
            self.irrelevant_features = False
        else:
            self.irrelevant_features = True
        if grid_shape is not None:
            self.draw_grid = True
            assert isinstance(grid_shape, tuple) and \
                (len(grid_shape) == 2 or len(grid_shape) == 4)
            # Could also assert that self.width is divisible by grid_shape[0], etc.
            self.grid_shape = grid_shape
        else:
            self.draw_grid = False
        self.goal_colour = (0, 255, 0)
        self.agent_colour = (0, 0, 255)
        self.term_colour = (255, 0, 0)
        self.bg_colour = (0, 0, 0)
        self.line_colour = (255, 255, 255)
        assert len(feature_space.shape) == 1
        relevant_dims = len(relevant_indices)
        irr_dims = len(self.irrelevant_indices)
        assert relevant_dims <= 2 and irr_dims <= 2, "Image observations are " \
            "supported only for 1- or 2-D feature spaces."
        # Shape has 1 appended for compatibility with Ray RLlib, IIRC
        super(ImageContinuous, self).__init__(shape=(width, height, 1),
                                              dtype=dtype, low=0, high=255)
        super(ImageContinuous, self).seed(seed=seed)
        if self.target_point is not None:
            if self.draw_grid:
                # Offset by half a cell so the point is drawn at the cell centre;
                # avoid modifying the caller's array in place
                target_point = target_point + 0.5
            self.target_point_pixel = self.convert_to_pixel(target_point)
    def generate_image(self, position, relevant=True):
        '''Generate the image corresponding to a single position in the feature
        (sub-)space.

        Parameters
        ----------
        position : np.array
            The position at which the agent is drawn
        relevant : bool
            Whether the position belongs to the relevant sub-space; the target
            point and terminal sub-spaces are drawn only for the relevant one
        '''
        # Use RGB
        image_ = Image.new("RGB", (self.width, self.height), color=self.bg_colour)
        # Use L for black and white 8-bit pixels instead of RGB in case not
        # using custom images
        # image_ = Image.new("L", (self.width, self.height))
        draw = ImageDraw.Draw(image_)
        # Draw in decreasing order of importance:
        # grid lines, term_spaces, etc. first, so that others are drawn over them
        if self.draw_grid:
            position = position.astype(float)
            position += 0.5
            offset = 0 if relevant else 2
            # +1 below because the x-axis is the concatenation dimension when
            # stitching together images in get_concatenated_image
            for i in range(1, self.grid_shape[0 + offset] + 1):
                # -1 so that the last line drawn does not fall outside the image
                x_ = i * self.width // self.grid_shape[0 + offset] - 1
                y_ = self.height
                start_pt = (x_, y_)
                y_ = 0
                end_pt = (x_, y_)
                draw.line([start_pt, end_pt], fill=self.line_colour)
            for j in range(1, self.grid_shape[1 + offset]):
                x_ = self.width
                y_ = j * self.height // self.grid_shape[1 + offset]
                start_pt = (x_, y_)
                x_ = 0
                end_pt = (x_, y_)
                draw.line([start_pt, end_pt], fill=self.line_colour)
        if self.term_spaces is not None and relevant:
            for term_space in self.term_spaces:
                low = self.convert_to_pixel(term_space.low)
                if self.draw_grid:
                    high = self.convert_to_pixel(term_space.high + 1.)
                else:
                    high = self.convert_to_pixel(term_space.high)
                leftUpPoint = tuple(low)
                rightDownPoint = tuple(high)
                twoPointList = [leftUpPoint, rightDownPoint]
                draw.rectangle(twoPointList, fill=self.term_colour)
        R = self.circle_radius
        if self.target_point is not None and relevant:
            # print("draw2", self.target_point_pixel)
            leftUpPoint = tuple(self.target_point_pixel - R)
            rightDownPoint = tuple(self.target_point_pixel + R)
            twoPointList = [leftUpPoint, rightDownPoint]
            draw.ellipse(twoPointList, fill=self.goal_colour)
        pos_pixel = self.convert_to_pixel(position)
        # print("draw1", pos_pixel)
        # Draw circle https://stackoverflow.com/a/2980931/11063709
        leftUpPoint = tuple(pos_pixel - R)
        rightDownPoint = tuple(pos_pixel + R)
        twoPointList = [leftUpPoint, rightDownPoint]
        draw.ellipse(twoPointList, fill=self.agent_colour)
        # np.array(image_) has shape (height, width, channels); transpose so the
        # returned array is (width, height, channels), matching this space's shape
        # ret_arr = np.array(image_).T  # For 2-D (single-channel) images
        ret_arr = np.transpose(np.array(image_), axes=(1, 0, 2))
        return ret_arr 
    def get_concatenated_image(self, obs):
        '''Gets the "stitched together" image made from images corresponding to
        each continuous sub-space within the continuous space, concatenated
        along the X-axis.
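
        For example, for a 4-D feature space with relevant_indices=[0, 1], the
        returned array stitches together the image for dimensions 0 and 1 and the
        image for the remaining (irrelevant) dimensions 2 and 3.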
        '''
        concatenated_image = []
        # For relevant/irrelevant sub-spaces:
        concatenated_image.append(self.generate_image(obs[self.relevant_indices]))
        if self.irrelevant_features:
            irr_image = self.generate_image(obs[self.irrelevant_indices], relevant=False)
            concatenated_image.append(irr_image)
        concatenated_image = np.concatenate(tuple(concatenated_image), axis=0)
        return np.atleast_3d(concatenated_image) # because Ray expects an 
        # image to have >=3 dims
    def convert_to_pixel(self, position):
        '''Convert a position in the feature space to pixel coordinates in the image.
        '''
        # It's implicit that both relevant and irrelevant sub-spaces have the
        # same max and min here:
        max_ = self.feature_space.high[self.relevant_indices]
        min_ = self.feature_space.low[self.relevant_indices]
        pos_pixel = (position - min_) / (max_ - min_)
        pos_pixel = (pos_pixel * self.shape[:2]).astype(int)
        return pos_pixel 
    def sample(self):
        '''Sample the underlying feature space and return its image representation.'''
        sampled = self.feature_space.sample()
        return self.get_concatenated_image(sampled) 
    def __repr__(self):
        return "{} with continuous underlying space of shape: {} and " \
               "images of resolution: {} and dtype: {}".format(
                   self.__class__, self.feature_space.shape, self.shape, self.dtype)
    def contains(self, x):
        """
        Return boolean specifying if x is a valid
        member of this space
        """
        # TODO compare pixel values against all possible images?
        return x.shape == (self.width, self.height, 1)
    def to_jsonable(self, sample_n):
        """Convert a batch of samples from this space to a JSONable data type."""
        # Not implemented for this space
        raise NotImplementedError
    def from_jsonable(self, sample_n):
        """Convert a JSONable data type to a batch of samples from this space."""
        # Not implemented for this space
        raise NotImplementedError
    def __eq__(self, other):
        raise NotImplementedError
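

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: the feature-space
    # bounds, target point and seed below are arbitrary example values. It builds
    # an ImageContinuous space over a 2-D Box and renders one observation.
    example_feature_space = Box(low=0.0, high=10.0, shape=(2,))
    example_space = ImageContinuous(example_feature_space,
                                    target_point=np.array([8.0, 8.0]),
                                    seed=0)
    example_obs = example_feature_space.sample()
    example_img = example_space.get_concatenated_image(example_obs)
    print("Image observation shape:", example_img.shape, "dtype:", example_img.dtype)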