import unittest
import numpy as np
from mdp_playground.spaces.image_continuous import ImageContinuous
from gym.spaces import Box
# import PIL.ImageDraw as ImageDraw
import PIL.Image as Image
# import PIL.ImageOps as ImageOps
class TestImageContinuous(unittest.TestCase):
    """Visual smoke test for ``ImageContinuous`` rendering.

    Each stage builds an ``ImageContinuous`` over a 2-D or 4-D ``Box`` state
    space and renders either a fixed state or a random sample. It converts
    the rendered array to an RGB PIL image and (optionally) shows it. There
    are no pixel-level assertions. The test passes as long as rendering and
    image conversion do not raise.
    """

    # Set to False (e.g. on a headless CI machine) to still render and
    # convert every image but skip opening an external viewer for it.
    SHOW_IMAGES = True

    def _show(self, arr):
        """Convert a rendered array to an RGB PIL image, show it if enabled, and return it."""
        # np.squeeze drops singleton axes of the rendered array (presumably a
        # leading batch/channel axis -- NOTE(review): confirm the renderer's
        # output shape).
        img = Image.fromarray(np.squeeze(arr), 'RGB')
        if self.SHOW_IMAGES:
            img.show()
        return img

    def test_image_continuous(self):
        """Render continuous states with/without target, terminal regions, extra dims, grid."""
        cs2 = Box(shape=(2,), low=0.0, high=20.0,)
        cs4 = Box(shape=(4,), low=0.0, high=20.0,)
        pos = np.array([5.0, 7.0])
        target = np.array([10, 10])

        # Plain continuous state, no target point.
        imc = ImageContinuous(cs2, width=100, height=100,)
        self._show(imc.generate_image(pos))
        # .save("cont_state_no_target.pdf")

        # Same state with a target point drawn.
        imc = ImageContinuous(cs2, target_point=target, width=100, height=100,)
        self._show(imc.generate_image(pos))
        # .save("cont_state_target.pdf")

        # Terminal sub-spaces: two axis-aligned boxes inside cs2.
        cs2_term1 = Box(low=np.array([2., 4.]), high=np.array([3., 6.]),)
        cs2_term2 = Box(low=np.array([12., 3.]), high=np.array([13., 4.]),)
        term_spaces = [cs2_term1, cs2_term2]
        imc = ImageContinuous(cs2, target_point=target, term_spaces=term_spaces,
                              width=100, height=100,)
        self._show(imc.get_concatenated_image(pos))
        # .save("cont_state_target_terminal_states.pdf")

        # Irrelevant features: a 4-D state where only a 2-D target is given.
        imc = ImageContinuous(cs4, target_point=target, width=400, height=400,)
        pos4 = np.array([5.0, 7.0, 10.0, 15.0])
        self._show(imc.get_concatenated_image(pos4))

        # Random sample (print(imc) would exercise __repr__).
        imc = ImageContinuous(cs4, target_point=target, width=400, height=400,)
        self._show(imc.sample())

        # Draw grid.
        imc = ImageContinuous(cs4, target_point=target, width=400, height=400,
                              grid=(5, 5))
        self._show(imc.sample())
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's CLI test runner.
    unittest.main()