Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
1.1 kB
3
Indexable
Never
    def select_action(self, state, valid_actions):
        """Sample an action from ``self.policy_old``, restricted to ``valid_actions``.

        The observation is passed through ``self.policy_old.actor``. Its output
        is used as a probability vector over ``self.action_dim`` actions (it is
        handed to ``Categorical`` as ``probs``). Probabilities of actions that
        are not in ``valid_actions`` are zeroed, and ``Categorical``
        renormalises the rest before sampling.

        Side effects: appends the state tensor, the sampled action tensor and
        its log-probability to ``self.buffer.states``, ``self.buffer.actions``
        and ``self.buffer.logprobs``.

        Args:
            state: Observation accepted by ``torch.FloatTensor``.
            valid_actions: Iterable of action indices allowed in this state.
                If it yields no indices, the unmasked policy is sampled.

        Returns:
            int: The sampled action index.
        """
        # 1.0 for each allowed action index, 0.0 elsewhere.
        valid_mask = torch.zeros(self.action_dim).to(device)
        for a in valid_actions:
            valid_mask[a] = 1.0

        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action_probs = self.policy_old.actor(state)

            # The actor already outputs probabilities, so mask them
            # multiplicatively and let Categorical renormalise. Applying
            # softmax to probabilities would flatten the policy towards
            # uniform over the valid actions.
            sample_probs = action_probs
            if valid_mask.any():
                sample_probs = action_probs * valid_mask
                if sample_probs.sum() <= 0:
                    # Every valid action has zero probability under the actor.
                    # Fall back to uniform over the valid actions instead of
                    # building an all-zero (invalid) distribution. The stored
                    # log-prob below will be -inf in this case.
                    sample_probs = valid_mask
            dist = Categorical(probs=sample_probs)

            # NOTE(review): the stored log-prob is taken under the *unmasked*
            # policy, as the original code did. This presumably matches how the
            # PPO update re-evaluates stored actions (the buffer keeps no
            # masks). Confirm against the evaluate step before changing it.
            dist_unmasked = Categorical(probs=action_probs)

            action = dist.sample()
            action_logprob = dist_unmasked.log_prob(action)

        self.buffer.states.append(state)
        self.buffer.actions.append(action)
        self.buffer.logprobs.append(action_logprob)
        return action.item()