Untitled
unknown
plain_text
3 years ago
1.1 kB
7
Indexable
def select_action(self, state, valid_actions):
    """Sample an action for `state` restricted to `valid_actions`,
    record the transition pieces in the rollout buffer, and return
    the chosen action index.

    Args:
        state: environment observation, convertible to a 1-D float tensor.
        valid_actions: iterable of action indices that are legal in `state`.

    Returns:
        int: the sampled action index.

    Side effects:
        Appends `state`, the sampled action, and its log-probability to
        `self.buffer.states` / `.actions` / `.logprobs`.
    """
    # Build a 0/1 mask over the action space; 1.0 marks a selectable action.
    valid_mask = torch.zeros(self.action_dim).to(device)
    for a in valid_actions:
        valid_mask[a] = 1.0

    with torch.no_grad():
        state = torch.FloatTensor(state).to(device)
        action_probs = self.policy_old.actor(state)

        # Mask invalid actions: push their scores to -1e10 so softmax
        # assigns them (numerically) zero probability. The guard skips
        # masking when no action is marked valid, to avoid an all -1e10
        # (degenerate uniform) distribution.
        action_probs_ = action_probs.clone()
        if valid_mask.sum().item() > 0:
            action_probs_[valid_mask == 0] = -1e10
        # dim=-1 made explicit: implicit-dim softmax is deprecated and
        # for this 1-D input resolves to the same axis.
        action_probs_ = F.softmax(action_probs_, dim=-1)

        dist = Categorical(action_probs_)   # masked distribution (sampling)
        dist2 = Categorical(action_probs)   # unmasked distribution
        action = dist.sample()
        # NOTE(review): the stored log-prob comes from the UNMASKED
        # distribution (dist2), not the masked one the action was sampled
        # from — confirm this matches the PPO update's expectation.
        action_logprob = dist2.log_prob(action)

    self.buffer.states.append(state)
    self.buffer.actions.append(action)
    self.buffer.logprobs.append(action_logprob)

    return action.item()