import random

import gym
import matplotlib.pyplot as plt
import numpy as np

# create Taxi environment
env = gym.make('Taxi-v3', new_step_api=True)

# initialize q-table
state_size = env.observation_space.n
print("states", env.observation_space.n)
action_size = env.action_space.n
print("actions", env.action_space.n)
qtable = np.zeros((state_size, action_size))

actions_dict = {
    0: "South",
    1: "North",
    2: "East",
    3: "West",
    4: "Pick-up",
    5: "Drop-off",
}

# hyperparameters
num_episodes = 600
learning_rate = 0.5   # alpha
epsilon = 0.75        # initial exploration rate
discount_rate = 0.99  # gamma
decay_rate = 0.01     # factor that decreases epsilon after each episode
max_steps = 99        # per episode

# training
for episode in range(num_episodes):
    # reset the environment
    state = env.reset()
    done = False

    for s in range(max_steps):
        # exploration-exploitation tradeoff
        if random.uniform(0, 1) < epsilon:
            # explore
            action = env.action_space.sample()
        else:
            # exploit
            action = np.argmax(qtable[state, :])

        # take action and observe reward
        new_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        # Q-learning update rule
        qtable[state, action] = qtable[state, action] + learning_rate * (
            reward + discount_rate * np.max(qtable[new_state, :]) - qtable[state, action])

        # update to our new state
        state = new_state

        # if done, finish episode
        if done:
            break

    # decrease epsilon
    epsilon = np.exp(-decay_rate * episode)

print(f"Training completed over {num_episodes} episodes")
input("Press Enter to watch trained agent...")

# watch trained agent
state = env.reset()
done = False
rewards = 0

imgs = []
fig, axs = plt.subplots(1, 2, figsize=(10, 4))

for s in range(max_steps):
    # follow the greedy policy learned during training
    action = np.argmax(qtable[state, :])
    new_state, reward, terminated, truncated, info = env.step(action)
    rewards += reward
    done = terminated or truncated

    # render the env
    env_screen = env.render(mode='rgb_array')
    print("step:", s, f"score: {rewards}")
    state = new_state
    print("state", state, "action", action, "done", done)

    plt.axis('off')
    im = axs[0].imshow(env_screen)
    im2 = axs[1].text(0.5, 0.5,
                      "Step:\n" + str(s + 1) +
                      "\n\nScore:\n" + str(rewards) +
                      "\n\nState:\n" + str(new_state) +
                      "\n\nAction:\n" + str(actions_dict[action]) +
                      "\n\nDone:\n" + str(done) +
                      "\n\nMax Qvalue:\n" + "{:.2f}".format(qtable[state, action]),
                      ha='center', va='center')
    imgs.append([im, im2])

    if done:
        # repeat the last frame so the video does not cut off the final state
        for _ in range(5):
            imgs.append([im, im2])
        break
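# The loop above collects pairs of artists in `imgs` but stops before turning
# them into a video. A minimal sketch of that missing step, assuming
# matplotlib's ArtistAnimation and the `fig` and `imgs` built above; the
# interval and the output filename 'taxi_agent.gif' are illustrative choices,
# not taken from the original script.
from matplotlib import animation

ani = animation.ArtistAnimation(fig, imgs, interval=300, blit=True, repeat_delay=1000)
ani.save('taxi_agent.gif', writer='pillow')  # or plt.show() to view interactively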