Untitled

mail@pastecode.io avatar
unknown
plain_text
7 months ago
2.8 kB
3
Indexable
Never
# Build the Taxi-v3 environment. new_step_api=True selects the step API whose
# env.step() result is unpacked below as
# (new_state, reward, terminated, truncated, info).
env = gym.make('Taxi-v3', new_step_api=True)

# Q-table: one row per discrete state, one column per discrete action,
# every entry starting at zero.
state_size = env.observation_space.n
action_size = env.action_space.n
print("estados", state_size)
print("acciones", action_size)
qtable = np.zeros((state_size, action_size))

# Human-readable label for each action index (used when displaying the agent).
actions_dict = dict(enumerate(
    ["South", "North", "East", "West", "Pick-up", "Drop-off"]
))

# Hyperparameters.
num_episodes = 600    # number of training episodes
max_steps = 99        # upper bound on steps per episode
learning_rate = 0.5   # alpha
discount_rate = 0.99  # gamma
epsilon = 0.75        # initial exploration probability
# Factor that shrinks epsilon after every episode:
#   epsilon = np.exp(-decay_rate * episode)
decay_rate = 0.01


# training
# Tabular Q-learning with an epsilon-greedy behaviour policy. Each episode
# runs until the environment reports termination/truncation or until
# max_steps steps have been taken, whichever comes first.
for episode in range(num_episodes):

    # Start a fresh episode.
    state = env.reset()
    done = False

    for s in range(max_steps):

        # Exploration-exploitation trade-off: with probability epsilon take a
        # random action, otherwise take the greedy action for this state.
        if random.uniform(0, 1) < epsilon:
            action = env.action_space.sample()
        else:
            action = np.argmax(qtable[state, :])

        # Take the action and observe the outcome.
        new_state, reward, terminated, truncated, info = env.step(action)

        # The episode is over as soon as the environment says so. Without
        # this assignment `done` would stay False forever and the loop would
        # keep stepping (and learning from) an already-finished environment.
        done = terminated or truncated

        # Q-learning update:
        #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
        td_target = reward + discount_rate * np.max(qtable[new_state, :])
        qtable[state, action] += learning_rate * (td_target - qtable[state, action])

        # Move on to the new state.
        state = new_state

        # Finish the episode once the environment is done.
        if done:
            break

    # Decay epsilon for the next episode.
    # NOTE: this overwrites the initial epsilon value; after episode 0 it
    # becomes exp(0) == 1.0 and then decreases with the episode index.
    epsilon = np.exp(-decay_rate * episode)

print(f"Training completed over {num_episodes} episodes")
input("Press Enter to watch trained agent...")

# watch trained agent
# Play one greedy episode (always the argmax action of the learned Q-table),
# printing progress and collecting one [image, text] artist pair per step in
# `imgs`.
state = env.reset()
done = False
rewards = 0

imgs = []
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
for s in range(max_steps):

    # Greedy action for the current state.
    action = np.argmax(qtable[state, :])
    # Because `action` is the argmax, this is the maximum Q-value of the state
    # the decision was made in. It must be read BEFORE `state` is replaced by
    # `new_state`; otherwise the panel would show Q(new_state, old action),
    # which is not a maximum of anything.
    max_q = qtable[state, action]

    new_state, reward, terminated, truncated, info = env.step(action)
    rewards += reward
    done = terminated or truncated

    # Render the environment as an RGB array.
    env_screen = env.render(mode='rgb_array')

    print("step:", s, f"score: {rewards}")
    state = new_state
    print("estado", state, "action", action, "done", done)

    plt.axis('off')

    # Left panel: rendered frame. Right panel: textual summary of the step.
    im = axs[0].imshow(env_screen)
    summary = (
        f"Step:\n{s + 1}"
        f"\n\nScore:\n{rewards}"
        f"\n\nState:\n{new_state}"
        f"\n\nAction:\n{actions_dict[action]}"
        f"\n\nDone:\n{done}"
        f"\n\nMax Qvalue:\n{max_q:.2f}"
    )
    im2 = axs[1].text(0.5, 0.5, summary, ha='center', va='center')
    imgs.append([im, im2])

    if done:
        # Repeat the final frame five more times so the last state is not cut
        # off at the end of the video.
        imgs.extend([im, im2] for _ in range(5))
        break
Leave a Comment