DQN Algorithm
import os

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim


class DeepQNetwork():
    def __init__(
        self,
        n_actions,
        input_shape,
        qnet,
        device,
        learning_rate=2e-4,
        reward_decay=0.99,
        replace_target_iter=1000,
        memory_size=10000,
        batch_size=32,
    ):
        # initialize parameters
        self.n_actions = n_actions
        self.input_shape = input_shape
        self.lr = learning_rate
        self.gamma = reward_decay
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.device = device
        self.learn_step_counter = 0
        self.init_memory()

        # Networks: the eval (online) net is trained, the target net is synced periodically.
        self.qnet_eval = qnet(self.input_shape, self.n_actions).to(self.device)
        self.qnet_target = qnet(self.input_shape, self.n_actions).to(self.device)
        self.qnet_target.eval()
        self.optimizer = optim.RMSprop(self.qnet_eval.parameters(), lr=self.lr)

    def choose_action(self, state, epsilon=0):
        # Epsilon-greedy action selection: greedy w.p. (1 - epsilon), random otherwise.
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        actions_value = self.qnet_eval.forward(state)
        if np.random.uniform() > epsilon:
            # greedy
            action = torch.max(actions_value, 1)[1].data.cpu().numpy()[0]
        else:
            # random
            action = np.random.randint(0, self.n_actions)
        return action

    def learn(self):
        # TODO(Lab-5): DQN core algorithm.
        # Check whether to replace the target network parameters.
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.qnet_target.load_state_dict(self.qnet_eval.state_dict())

        # Sample a batch of transitions from the replay memory.
        if self.memory_counter > self.memory_size:
            sample_index = np.random.choice(self.memory_size, size=self.batch_size)
        else:
            sample_index = np.random.choice(self.memory_counter, size=self.batch_size)

        b_s = torch.FloatTensor(self.memory["s"][sample_index]).to(self.device)
        b_a = torch.LongTensor(self.memory["a"][sample_index]).to(self.device)
        b_r = torch.FloatTensor(self.memory["r"][sample_index]).to(self.device)
        b_s_ = torch.FloatTensor(self.memory["s_"][sample_index]).to(self.device)
        b_d = torch.FloatTensor(self.memory["done"][sample_index]).to(self.device)

        # TD error, vanilla DQN: bootstrap from the target network's max Q-value.
        q_curr_eval = self.qnet_eval(b_s).gather(1, b_a)
        q_next_target = self.qnet_target(b_s_).detach()
        next_state_values = q_next_target.max(1)[0].view(-1, 1)
        q_curr_recur = b_r + (1 - b_d) * self.gamma * next_state_values

        # TD error, Double DQN variant: select the next action with the eval net,
        # evaluate it with the target net. This overrides the vanilla target above.
        q_curr_eval = self.qnet_eval(b_s).gather(1, b_a)
        q_next_target = self.qnet_target(b_s_).detach()
        q_next_eval = self.qnet_eval(b_s_).detach()
        next_state_values = q_next_target.gather(1, q_next_eval.max(1)[1].unsqueeze(1))
        q_curr_recur = b_r + (1 - b_d) * self.gamma * next_state_values

        # Compute the loss and optimize the network
        self.loss = F.smooth_l1_loss(q_curr_eval, q_curr_recur)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
        self.learn_step_counter += 1
        return self.loss.detach().cpu().numpy()

    # replay buffer
    def init_memory(self):
        self.memory = {
            "s": np.zeros((self.memory_size, *self.input_shape)),
            "a": np.zeros((self.memory_size, 1)),
            "r": np.zeros((self.memory_size, 1)),
            "s_": np.zeros((self.memory_size, *self.input_shape)),
            "done": np.zeros((self.memory_size, 1)),
        }

    def store_transition(self, s, a, r, s_, d):
        if not hasattr(self, 'memory_counter'):
            self.memory_counter = 0
        if self.memory_counter <= self.memory_size:
            # Fill the buffer sequentially until it is full.
            index = self.memory_counter % self.memory_size
        else:
            # Once full, overwrite a random slot.
            index = np.random.randint(self.memory_size)
        self.memory["s"][index] = s
        self.memory["a"][index] = np.array(a).reshape(-1, 1)
        self.memory["r"][index] = np.array(r).reshape(-1, 1)
        self.memory["s_"][index] = s_
        self.memory["done"][index] = np.array(d).reshape(-1, 1)
        self.memory_counter += 1

    def save_load_model(self, op, path="save", fname="qnet.pt"):
        if not os.path.exists(path):
            os.makedirs(path)
        file_path = os.path.join(path, fname)
        if op == "save":
            torch.save(self.qnet_eval.state_dict(), file_path)
        elif op == "load":
            self.qnet_eval.load_state_dict(torch.load(file_path, map_location=self.device))
            self.qnet_target.load_state_dict(torch.load(file_path, map_location=self.device))
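

# --- Example usage (illustrative sketch, not part of the original paste) ---
# A minimal example of how this agent might be wired up, assuming a simple
# fully connected QNet and the classic OpenAI Gym API ("CartPole-v1", where
# env.step returns a 4-tuple). The QNet class, environment, and the
# hyperparameters below are assumptions for demonstration only.
import gym
import torch.nn as nn


class QNet(nn.Module):
    # Small MLP mapping a flat observation to one Q-value per action.
    def __init__(self, input_shape, n_actions):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x):
        return self.net(x)


if __name__ == "__main__":
    env = gym.make("CartPole-v1")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = DeepQNetwork(
        n_actions=env.action_space.n,
        input_shape=env.observation_space.shape,
        qnet=QNet,
        device=device,
    )

    epsilon = 1.0
    for episode in range(200):
        state = env.reset()
        done = False
        while not done:
            action = agent.choose_action(state, epsilon=epsilon)
            next_state, reward, done, _ = env.step(action)
            agent.store_transition(state, action, reward, next_state, done)
            # Start learning once enough transitions are stored.
            if agent.memory_counter > agent.batch_size:
                agent.learn()
            state = next_state
        # Decay exploration over time.
        epsilon = max(0.05, epsilon * 0.98)

    agent.save_load_model("save")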