# imports (gym 0.25.x, which accepts the new_step_api flag)
import gym
import numpy as np
import random
import matplotlib.pyplot as plt

# create Taxi environment
env = gym.make('Taxi-v3', new_step_api=True)
# initialize q-table
state_size = env.observation_space.n
print("states", env.observation_space.n)
action_size = env.action_space.n
print("actions", env.action_space.n)
qtable = np.zeros((state_size, action_size))
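# Taxi-v3 has 500 discrete states and 6 discrete actions,
# so the table starts as a 500 x 6 array of zeros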
actions_dict = {
    0: "South",
    1: "North",
    2: "East",
    3: "West",
    4: "Pick-up",
    5: "Drop-off",
}
# hyperparameters to be defined
num_episodes = 600
learning_rate = 0.5 # alpha
epsilon = 0.75
discount_rate = 0.99 # gamma
decay_rate = 0.01 # per-episode decay factor for epsilon: epsilon = np.exp(-decay_rate*episode)
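# note: epsilon is overwritten at the end of each episode with
# exp(-decay_rate*episode): ~1.0 after episode 0, ~0.37 by episode 100,
# ~0.002 by episode 600, so exploration fades out over training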
max_steps = 99 # per episode
# training
for episode in range(num_episodes):
    # reset the environment
    state = env.reset()
    done = False
    for s in range(max_steps):
        # exploration-exploitation tradeoff
        if random.uniform(0, 1) < epsilon:
            # explore
            action = env.action_space.sample()
        else:
            # exploit
            action = np.argmax(qtable[state, :])
        # take action and observe reward
        new_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Q-learning algorithm
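        # Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)),
        # i.e. move Q(s,a) a step of size alpha toward the TD target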
        qtable[state, action] = qtable[state, action] + learning_rate * (
            reward + discount_rate * np.max(qtable[new_state, :]) - qtable[state, action]
        )
        # update to our new state
        state = new_state
        # if done, finish episode
        if done:
            break
    # decrease epsilon
    epsilon = np.exp(-decay_rate * episode)
print(f"Training completed over {num_episodes} episodes")
input("Press Enter to watch trained agent...")
# watch trained agent
state = env.reset()
done = False
rewards = 0
imgs = []
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
for s in range(max_steps):
    # print("TRAINED AGENT")
    # print("Step {}".format(s+1))
    action = np.argmax(qtable[state, :])
    new_state, reward, terminated, truncated, info = env.step(action)
    rewards += reward
    done = terminated or truncated
    # render the env
    env_screen = env.render(mode='rgb_array')
    print("step:", s, f"score: {rewards}")
    state = new_state
    print("state", state, "action", action, "done", done)
    # print("max q-value", qtable[state, action])
    plt.axis('off')
    im = axs[0].imshow(env_screen)
    im2 = axs[1].text(
        0.5, 0.5,
        "Step:\n" + str(s + 1)
        + "\n\nScore:\n" + str(rewards)
        + "\n\nState:\n" + str(new_state)
        + "\n\nAction:\n" + str(actions_dict[action])
        + "\n\nDone:\n" + str(done)
        + "\n\nMax Q-value:\n" + "{:.2f}".format(qtable[state, action]),
        ha='center', va='center',
    )
    imgs.append([im, im2])
    if done:
        # repeat the last frame so the video doesn't cut off the final state
        imgs.extend([[im, im2]] * 5)
        break
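
# Not part of the original snippet: a minimal sketch of how the collected
# frame artists in `imgs` could be played back and saved with matplotlib's
# ArtistAnimation; the interval, filename, and pillow writer are assumptions.
from matplotlib import animation

ani = animation.ArtistAnimation(fig, imgs, interval=500, repeat_delay=1000)
ani.save("taxi_agent.gif", writer="pillow")  # writing a gif requires Pillow
plt.show()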