import numpy as np
import gym
import random
import time

import imageio
from PIL import Image
import PIL.ImageDraw as ImageDraw
import matplotlib.pyplot as plt


def _label_with_episode_number(frame, episode_num):
    """Return ``frame`` as a PIL image stamped with a 1-based episode caption.

    Args:
        frame: Image data accepted by ``Image.fromarray``.
        episode_num: Zero-based episode index; the caption shows
            ``episode_num + 1``.

    Returns:
        The PIL image built from ``frame`` with the caption drawn onto it.
    """
    image = Image.fromarray(frame)
    canvas = ImageDraw.Draw(image)

    # White text when the mean pixel value is below 128, black otherwise,
    # so the caption contrasts with the frame.
    is_dark = np.mean(image) < 128
    caption_color = (255, 255, 255) if is_dark else (0, 0, 0)

    # Place the caption at a position proportional to the image size.
    width, height = image.size
    anchor = (width / 20, height / 18)
    canvas.text(anchor, f'Episode: {episode_num+1}', fill=caption_color)

    return image


# Environment shared by the training loop and the recorded replay.
env = gym.make('MsPacman-v0')
action_space_size = env.action_space.n

# Q-table for reinforcement learning: indexed as qTable[state, action],
# 229 rows by 6 columns, all values starting at zero.
# NOTE(review): 229 is presumably the number of distinct state indices the
# table is expected to see -- confirm against the observations actually used.
qTable = np.zeros((229, 6))
# Length-6 work buffer (same width as a Q-table row).
temp = np.zeros(6)

# Training length.
num_episodes, max_steps_per_episode = 400, 1000

# Alpha (learning rate) and gamma (discount rate) for the Q-learning update.
learning_rate, discount_rate = 0.1, 0.9

# Exploration schedule: the rate starts at its maximum and is decayed
# exponentially toward the minimum as episodes progress.
max_explor, min_explor = 1, 0.01
explor_rate = max_explor
decay = 0.01

# Bookkeeping for the recorded video and the reward statistics.
frames = list()
rewards_all_episodes = list()
divider = 20  # number of episodes per average-reward report
sumOfRewards = 0
q_table_choosed = 0


def max_action_given_state(state):
    """Return the greedy action (index of the largest Q-value) for ``state``.

    The Q-table row is selected by the first scalar element of ``state``:
    for a 3-D array observation that is ``state[0][0][0]``, and for a plain
    integer state it is the state itself.

    Args:
        state: An integer, or an array-like whose first element is a valid
            row index into ``qTable``.

    Returns:
        int: Column index of the largest Q-value in the selected row. Ties
        resolve to the lowest index.

    Raises:
        IndexError: If ``state`` is empty or its first element is outside
            the rows of ``qTable``.
    """
    # Reduce the observation to a single row key up front instead of
    # fancy-indexing the whole table with the full observation and then
    # discarding everything but one element.
    row = int(np.asarray(state).flat[0])

    # argmax always yields an index (first occurrence on ties), so no
    # sentinel is needed and rows whose values are all negative still
    # produce a valid action.
    return int(np.argmax(qTable[row]))


# Q-Learning algorithm: run `num_episodes` epsilon-greedy episodes, updating
# qTable after every environment step and printing progress statistics.
for episode in range(num_episodes):
    state = env.reset()

    gameOver = False
    currentRewards = 0
    q_table_choosed = 0  # greedy (Q-table) actions chosen in this episode
    steps_taken = 0      # environment steps actually executed in this episode

    for step in range(max_steps_per_episode):
        # The render result is stored in the module-level name `enviroment`
        # (spelling preserved) because code outside this loop may read it.
        enviroment = env.render()
        time.sleep(0.02)

        # Exploration trade-off: exploit the Q-table when the random draw
        # exceeds the current exploration rate, otherwise act randomly.
        exploration_rate_threshold = random.uniform(0, 1)
        if exploration_rate_threshold > explor_rate:
            q_table_choosed += 1
            action = max_action_given_state(state)
        else:
            # Resample until the action is one of the allowed ones.
            # NOTE(review): action 1 is excluded although the Q-table has a
            # column for it -- confirm this is intentional.
            action = env.action_space.sample()
            while action not in (0, 2, 3, 4, 5):
                action = env.action_space.sample()
        nextState, reward, gameOver, info = env.step(action)
        steps_taken += 1

        # Update Q-Table for Q(s,a):
        #   Q(s,a) <- (1 - lr) * Q(s,a) + lr * (reward + gamma * max Q(s', .))
        # NOTE(review): `state` / `nextState` are used directly as row
        # indices, so an array-valued observation triggers NumPy fancy
        # indexing over every value it contains -- confirm this matches the
        # row that max_action_given_state reads.
        qTable[state, action] = qTable[state, action] * (1 - learning_rate) + \
            learning_rate * (reward + discount_rate * np.max(qTable[nextState, :]))

        state = nextState
        currentRewards += reward
        sumOfRewards += reward

        if gameOver:
            break

    # Exploration rate decay
    explor_rate = min_explor + \
        (max_explor - min_explor) * np.exp(-decay * episode)

    # Share of this episode's actions that came from the Q-table. Dividing by
    # the number of executed steps (not the last loop index) keeps the value
    # within 0-100 and avoids dividing by zero for a one-step episode.
    percent = (q_table_choosed / steps_taken) * 100 if steps_taken else 0.0
    print("  REWARDS:  ", currentRewards, " Q-Table actions percentage: ", percent,"%")

    # Report the average reward once per completed window of `divider`
    # episodes, so each average covers exactly `divider` episodes and the
    # final window is reported too.
    if (episode + 1) % divider == 0:
        avg_reward = sumOfRewards / divider
        print("Episodes: ", episode + 1 - divider, "-", episode + 1)
        print("Avg reward : ", avg_reward, "\n")
        sumOfRewards = 0

# Replay the greedy policy for 5 episodes, collecting one labelled frame per
# step, then write all collected frames to a video file.
for episode in range(5):
    state = env.reset()
    gameOver = False
    print("Episode num: ", episode + 1, "  \n\n\n\n")
    time.sleep(1)

    for step in range(max_steps_per_episode):
        # Show the game on screen.
        env.render()
        # Store frames for creating video: grab the current screen as an RGB
        # array so every stored frame reflects this step of the replay.
        frame = env.render(mode='rgb_array')
        frames.append(_label_with_episode_number(frame, episode))

        time.sleep(0.1)

        # Always act greedily with respect to the Q-table during replay.
        nextState, reward, gameOver, info = env.step(max_action_given_state(state))

        if gameOver:
            break

        state = nextState
env.close()

imageio.mimwrite('tempVideo'+'.mp4', frames, fps=24, macro_block_size=None)