Untitled
unknown
plain_text
3 years ago
1.1 kB
5
Indexable
def select_action(self, state, valid_actions):
    """Sample an action restricted to ``valid_actions`` and record the
    transition pieces (state, action, log-prob) in the rollout buffer.

    Args:
        state: environment observation; converted to a float tensor on
            ``device``. Assumed 1-D (flat feature vector) — TODO confirm.
        valid_actions: iterable of action indices that are legal in the
            current state.

    Returns:
        int: the sampled action index.
    """
    # Binary mask over the discrete action space: 1.0 marks a legal action.
    valid_mask = torch.zeros(self.action_dim).to(device)
    for a in valid_actions:
        valid_mask[a] = 1.0

    with torch.no_grad():
        state = torch.FloatTensor(state).to(device)
        action_probs = self.policy_old.actor(state)

        # Push invalid actions' scores toward -inf before renormalizing so
        # they receive ~zero sampling probability. If no action is marked
        # valid, fall back to the unmasked actor output.
        action_probs_ = action_probs.clone()
        if valid_mask.sum().item() > 0:
            action_probs_[valid_mask == 0] = -1e10
        # Explicit dim: bare F.softmax(x) is deprecated and guesses the axis.
        # For the 1-D actor output this is identical to the implicit choice.
        action_probs_ = F.softmax(action_probs_, dim=-1)

        dist = Categorical(action_probs_)
        # NOTE(review): the stored log-prob is taken from the *unmasked*
        # distribution while the action is sampled from the masked one.
        # Presumably intentional so the PPO ratio is computed against the
        # raw policy — confirm against the policy-update step.
        dist2 = Categorical(action_probs)

        action = dist.sample()
        action_logprob = dist2.log_prob(action)

    self.buffer.states.append(state)
    self.buffer.actions.append(action)
    self.buffer.logprobs.append(action_logprob)

    return action.item()
Editor is loading...