# Taxi Sim Q-Learning
#
# Tabular Q-learning agent for the Gym Taxi-v3 environment.
# (Originally shared as a pastecode.io snippet.)
import numpy as np
import gym
import time

# Default hyperparameters for train().
NEPOCHS = 10000       # number of training episodes
MAX_ITER = 1000       # upper bound on the steps taken within one episode
EPSILON = 0.1         # probability of taking a random (exploratory) action
DISCOUNT = 0.6        # weight of the best estimated next-state value
LEARNING_RATE = 0.1   # step size of each Q-value update


def train(env, nepochs=NEPOCHS, max_iter=MAX_ITER, epsilon=EPSILON,
          learning_rate=LEARNING_RATE, discount=DISCOUNT):
    """Learn a Q-table for ``env`` with epsilon-greedy tabular Q-learning.

    The environment is used through ``env.reset()`` (returning the initial
    observation) and ``env.step(action)`` (returning the 4-tuple
    ``(next_obs, reward, done, info)``). Observations and actions are used
    directly as row/column indices of the table.

    Args:
        env: Environment exposing ``observation_space.n``,
            ``action_space.n``, ``reset()`` and ``step(action)``.
        nepochs: Number of episodes to run.
        max_iter: Maximum number of steps per episode.
        epsilon: Probability of choosing a uniformly random action instead
            of the currently best-valued one.
        learning_rate: Fraction of the temporal-difference error applied to
            the table on each update.
        discount: Factor applied to the best next-state value in the
            update target.

    Returns:
        A NumPy array of shape ``(observation_space.n, action_space.n)``
        holding the learned action values.
    """
    n_actions = env.action_space.n
    qtable = np.zeros((env.observation_space.n, n_actions))

    for _episode in range(nepochs):
        obs = env.reset()

        for _step in range(max_iter):
            # Explore with probability epsilon, otherwise act greedily.
            if np.random.uniform(0, 1) < epsilon:
                action = np.random.choice(n_actions)
            else:
                action = np.argmax(qtable[obs])

            next_obs, reward, done, _info = env.step(action)

            # Apply the update BEFORE leaving the episode: breaking first
            # would discard the final transition, so the reward that ends
            # an episode (e.g. reaching the goal) would never be learned.
            td_target = reward + discount * np.max(qtable[next_obs])
            qtable[obs, action] += learning_rate * (td_target - qtable[obs, action])

            if done:
                break
            obs = next_obs

    return qtable
    
def execute(env, qtable):
    """Play one episode greedily from ``qtable``, rendering every frame.

    The environment is reset and rendered once, then the highest-valued
    action for the current observation is applied repeatedly, rendering
    after each step, until ``env.step`` reports the episode as done.

    Args:
        env: Environment exposing ``reset()``, ``render()`` and
            ``step(action)`` returning ``(obs, reward, done, info)``.
        qtable: Action-value table indexed by observation.

    Returns:
        The number of steps taken in the episode.
    """
    state = env.reset()
    env.render()

    steps = 0
    while True:
        best_action = np.argmax(qtable[state])
        state, _reward, finished, _info = env.step(best_action)
        env.render()
        steps += 1
        if finished:
            return steps
    
def executeTwo(env, qtable):
    """Play one greedy episode, printing each frame as ANSI text.

    Same policy as a plain greedy rollout, but after every step the frame
    obtained from ``env.render(mode='ansi')`` is printed behind a
    terminal erase sequence, followed by a 0.5 second pause.

    Args:
        env: Environment exposing ``reset()``, ``render(...)`` and
            ``step(action)`` returning ``(obs, reward, done, info)``.
        qtable: Action-value table indexed by observation.

    Returns:
        The number of steps taken in the episode.
    """
    erase_sequence = '\x1b[1J'  # emitted in front of every printed frame
    frame_delay = 0.5           # seconds to pause after each frame

    state = env.reset()
    env.render()

    steps = 0
    finished = False
    while not finished:
        greedy_action = np.argmax(qtable[state])
        state, _reward, finished, _info = env.step(greedy_action)
        frame = env.render(mode='ansi')
        print(erase_sequence + frame)
        time.sleep(frame_delay)
        steps += 1

    return steps


# Script entry: build the Taxi-v3 environment, learn a Q-table with the
# default hyperparameters, replay one greedy episode and print how many
# steps that episode took.
env = gym.make('Taxi-v3')
qtable = train(env)
result = execute(env, qtable)
# Alternative playback that prints each frame as ANSI text with a pause:
## result = executeTwo(env, qtable)
print(result)