Untitled

mail@pastecode.io avatar
unknown
python
18 days ago
2.2 kB
19
Indexable
Never
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define the network
class Net(nn.Module):
    """One-hidden-layer model mapping a (state, action) pair to next-state probabilities.

    Args:
        state_size: Number of features in the state encoding; also the number
            of output classes (one per possible next state).
        action_size: Number of features in the action encoding.
        hidden_size: Width of the single hidden layer.
    """

    def __init__(self, state_size, action_size, hidden_size=64):
        super().__init__()
        # Input is the state and action encodings concatenated along the last dim.
        self.fc1 = nn.Linear(state_size + action_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, state_size)

    def forward(self, state, action):
        """Return a probability distribution over next states.

        Args:
            state: Tensor whose last dimension has ``state_size`` features
                (the training loop feeds one-hot rows of shape (1, state_size)).
            action: Tensor whose last dimension has ``action_size`` features,
                with leading dimensions matching ``state``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of ``state_size`` whose entries sum to 1.
        """
        x = torch.cat([state, action], dim=-1)
        # torch.sigmoid is the preferred spelling of F.sigmoid; same result.
        x = torch.sigmoid(self.fc1(x))
        # NOTE(review): this returns probabilities, not logits. If the loss is
        # nn.CrossEntropyLoss (which applies log-softmax internally) the output
        # is normalized twice; confirm against how loss_fn is constructed.
        return F.softmax(self.fc2(x), dim=-1)

# Define the accuracy function
def accuracy(preds, y):
    """Return the fraction of predicted labels that equal the target labels.

    Args:
        preds: Non-empty sequence of integer label tensors. Each tensor may be
            of any shape; all elements are flattened and concatenated in order.
        y: Sequence of target label tensors holding the same total number of
            elements as ``preds``, in matching order.

    Returns:
        0-dim float tensor with the accuracy in ``[0, 1]``.

    Raises:
        ValueError: If ``preds`` and ``y`` contain different numbers of labels.
        ``torch.cat`` also raises if either sequence is empty.
    """
    # Flatten each piece before comparing. Comparing an (N, 1) tensor with an
    # (N,) tensor would broadcast to an (N, N) matrix and yield a meaningless
    # score, which is what happened with (1, 1) predictions vs (1,) targets.
    flat_preds = torch.cat([p.reshape(-1) for p in preds])
    flat_y = torch.cat([t.reshape(-1) for t in y])
    if flat_preds.numel() != flat_y.numel():
        raise ValueError(
            f"preds has {flat_preds.numel()} labels but y has {flat_y.numel()}"
        )
    return (flat_preds == flat_y).float().mean()

# NOTE(review): env, state_size, action_size, net, loss_fn and optimizer are
# not defined in this file as shown; they must be created before this loop.
for episode in range(1000):  # adjust as needed
    print("episode = ", episode)
    # NOTE(review): assumes the pre-0.26 gym API, where reset() returns only the
    # observation and step() returns a 4-tuple. Newer gym/gymnasium return
    # (obs, info) and a 5-tuple respectively; confirm the installed version.
    state = env.reset()
    done = False
    # Per-episode buffers of predicted and true next-state labels.
    preds_array = []
    next_state_array = []
    while not done:
        action = env.action_space.sample()  # replace with your action selection method
        next_state, reward, done, info = env.step(action)

        # Prepare the data: one-hot encode the discrete state and action.
        state_tensor = F.one_hot(torch.tensor([state]), num_classes=state_size).float()
        action_tensor = F.one_hot(torch.tensor([action]), num_classes=action_size).float()
        next_state_tensor = torch.tensor([next_state])

        # Forward pass
        preds = net(state_tensor, action_tensor)
        # argmax over classes has shape (1,), matching next_state_tensor. The
        # previous .unsqueeze(0) produced (1, 1), which broadcast incorrectly
        # when compared against the targets.
        preds_array.append(preds.argmax(dim=1))
        next_state_array.append(next_state_tensor)

        # Calculate the loss
        loss = loss_fn(preds, next_state_tensor)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Move to the next state
        state = next_state

    # Report accuracy over the whole episode every 50 episodes. This check used
    # to live inside the step loop, where it printed and reset the buffers after
    # every step, so each reported accuracy covered at most one prediction.
    if episode % 50 == 0 and episode != 0:
        print("next_state_array = ", next_state_array)
        print("Length = ", len(preds_array))
        acc = accuracy(preds_array, next_state_array)
        print("The accuracy is ", acc)
Leave a Comment