# DQN Algorithm
import os

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim


class DeepQNetwork:
    def __init__(
        self,
        n_actions,
        input_shape,
        qnet,
        device,
        learning_rate=2e-4,
        reward_decay=0.99,
        replace_target_iter=1000,
        memory_size=10000,
        batch_size=32,
    ):
        # Initialize hyperparameters.
        self.n_actions = n_actions
        self.input_shape = input_shape
        self.lr = learning_rate
        self.gamma = reward_decay
        self.replace_target_iter = replace_target_iter
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.device = device
        self.learn_step_counter = 0
        self.init_memory()

        # Networks: qnet_eval is trained; qnet_target is a periodically synced copy.
        self.qnet_eval = qnet(self.input_shape, self.n_actions).to(self.device)
        self.qnet_target = qnet(self.input_shape, self.n_actions).to(self.device)
        self.qnet_target.eval()
        self.optimizer = optim.RMSprop(self.qnet_eval.parameters(), lr=self.lr)
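
    # Epsilon-greedy action selection: with probability epsilon a random action is
    # returned, otherwise the greedy action under the current evaluation network.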
    def choose_action(self, state, epsilon=0):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():  # no gradients needed for action selection
            actions_value = self.qnet_eval(state)
        if np.random.uniform() > epsilon:  # greedy action
            action = actions_value.max(1)[1].item()
        else:  # random action
            action = np.random.randint(0, self.n_actions)
        return action
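
    # One DQN update: sample a minibatch from the replay buffer, form the TD target
    # with the (periodically synced) target network, and take one optimizer step.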
    def learn(self):
        # TODO(Lab-5): DQN core algorithm.
        # Periodically copy the evaluation network's weights into the target network.
        if self.learn_step_counter % self.replace_target_iter == 0:
            self.qnet_target.load_state_dict(self.qnet_eval.state_dict())

        # Sample a minibatch of transitions from the replay buffer.
        if self.memory_counter > self.memory_size:
            sample_index = np.random.choice(self.memory_size, size=self.batch_size)
        else:
            sample_index = np.random.choice(self.memory_counter, size=self.batch_size)
        b_s = torch.FloatTensor(self.memory["s"][sample_index]).to(self.device)
        b_a = torch.LongTensor(self.memory["a"][sample_index]).to(self.device)
        b_r = torch.FloatTensor(self.memory["r"][sample_index]).to(self.device)
        b_s_ = torch.FloatTensor(self.memory["s_"][sample_index]).to(self.device)
        b_d = torch.FloatTensor(self.memory["done"][sample_index]).to(self.device)
        # TD target. Vanilla DQN would bootstrap from the target network's own maximum:
        #   q_curr_eval = self.qnet_eval(b_s).gather(1, b_a)
        #   next_state_values = self.qnet_target(b_s_).detach().max(1)[0].view(-1, 1)
        #   q_curr_recur = b_r + (1 - b_d) * self.gamma * next_state_values
        # Here the Double DQN variant is used instead: the evaluation network selects
        # the next action and the target network evaluates it.
        q_curr_eval = self.qnet_eval(b_s).gather(1, b_a)
        q_next_target = self.qnet_target(b_s_).detach()
        q_next_eval = self.qnet_eval(b_s_).detach()
        next_state_values = q_next_target.gather(1, q_next_eval.max(1)[1].unsqueeze(1))
        q_curr_recur = b_r + (1 - b_d) * self.gamma * next_state_values

        # Compute the Huber loss between predicted and bootstrapped Q-values,
        # then take one gradient step on the evaluation network.
        self.loss = F.smooth_l1_loss(q_curr_eval, q_curr_recur)
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

        self.learn_step_counter += 1
        return self.loss.detach().cpu().numpy()

    # Replay buffer.
    def init_memory(self):
        self.memory_counter = 0
        self.memory = {
            "s": np.zeros((self.memory_size, *self.input_shape)),
            "a": np.zeros((self.memory_size, 1)),
            "r": np.zeros((self.memory_size, 1)),
            "s_": np.zeros((self.memory_size, *self.input_shape)),
            "done": np.zeros((self.memory_size, 1)),
        }

    def store_transition(self, s, a, r, s_, d):
        # Write into the next free slot; once the buffer is full, overwrite a
        # randomly chosen slot.
        if self.memory_counter <= self.memory_size:
            index = self.memory_counter % self.memory_size
        else:
            index = np.random.randint(self.memory_size)
        self.memory["s"][index] = s
        self.memory["a"][index] = np.array(a).reshape(-1, 1)
        self.memory["r"][index] = np.array(r).reshape(-1, 1)
        self.memory["s_"][index] = s_
        self.memory["done"][index] = np.array(d).reshape(-1, 1)
        self.memory_counter += 1

    def save_load_model(self, op, path="save", fname="qnet.pt"):
        if not os.path.exists(path):
            os.makedirs(path)
        file_path = os.path.join(path, fname)
        if op == "save":
            torch.save(self.qnet_eval.state_dict(), file_path)
        elif op == "load":
            self.qnet_eval.load_state_dict(torch.load(file_path, map_location=self.device))
            self.qnet_target.load_state_dict(torch.load(file_path, map_location=self.device))
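

# --- Minimal usage sketch (not part of the original lab code) ---
# Assumes a hypothetical `QNet(input_shape, n_actions)` nn.Module and a Gym-style
# environment `env`; all names, hyperparameters, and the epsilon schedule below are
# illustrative only and shown as comments so the file stays importable.
#
#     agent = DeepQNetwork(
#         n_actions=env.action_space.n,
#         input_shape=env.observation_space.shape,
#         qnet=QNet,
#         device="cuda" if torch.cuda.is_available() else "cpu",
#     )
#     epsilon = 1.0
#     state = env.reset()
#     for step in range(total_steps):
#         action = agent.choose_action(state, epsilon)
#         next_state, reward, done, _ = env.step(action)
#         agent.store_transition(state, action, reward, next_state, done)
#         if step > 1000 and step % 4 == 0:
#             loss = agent.learn()
#         epsilon = max(0.05, epsilon * 0.999)
#         state = next_state if not done else env.reset()
#     agent.save_load_model("save")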