Untitled

 avatar
unknown
python
10 months ago
2.3 kB
3
Indexable
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Define the network
class Net(nn.Module):
    """Two-layer MLP that predicts a distribution over next states.

    The state and action tensors are concatenated along the last dimension,
    passed through one hidden ReLU layer, and mapped to ``state_size``
    softmax probabilities.

    Args:
        state_size: Number of output classes (width of the final layer).
        action_size: Accepted for interface compatibility; not used to size
            any layer.
        hidden_size: Width of the hidden layer.
        input_size: Width of the concatenated ``[state, action]`` input. The
            default of 2 corresponds to one scalar state feature plus one
            scalar action feature; pass a larger value when feeding wider
            (e.g. one-hot) encodings.
    """

    def __init__(self, state_size, action_size, hidden_size=64, input_size=2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, state_size)

    def forward(self, state, action):
        """Return next-state probabilities for the given state/action pair.

        Args:
            state: Float tensor whose last dimension holds the state features.
            action: Float tensor whose last dimension holds the action
                features; all leading dimensions must match ``state``.

        Returns:
            Tensor of softmax probabilities with last dimension
            ``state_size``.

        Note:
            The output is already normalised. ``nn.CrossEntropyLoss`` expects
            raw logits, so train on these outputs with ``nn.NLLLoss`` applied
            to their logarithm instead.
        """
        features = torch.cat([state, action], dim=-1)
        hidden = F.relu(self.fc1(features))
        next_state_probs = F.softmax(self.fc2(hidden), dim=-1)
        return next_state_probs

# Define the accuracy function
def accuracy(preds, y):
    """Return the fraction of predictions whose top-scoring class matches ``y``.

    Args:
        preds: Score tensor; the predicted class is the index of the maximum
            along dimension 1.
        y: Label tensor compared element-wise with the predicted indices.

    Returns:
        A float tensor holding the number of matches divided by the size of
        the first dimension of the comparison result.
    """
    predicted_classes = preds.max(dim=1).indices
    hits = (predicted_classes == y).float()
    return hits.sum() / hits.size(0)

# Initialize the environment and the network
env = gym.make('CliffWalking-v0')
# CliffWalking uses a Discrete observation space: `.n` is the number of
# distinct states, i.e. the number of classes the network predicts over.
state_size = env.observation_space.n
action_size = env.action_space.n
net = Net(state_size, action_size)

# Define the optimizer and the loss function
optimizer = optim.Adam(net.parameters())
# Net.forward already applies softmax, so the loss is NLL on log-probabilities;
# CrossEntropyLoss would apply a second (log-)softmax on top of it.
loss_fn = nn.NLLLoss()

# Training loop
for episode in range(1000):  # adjust as needed
    reset_result = env.reset()
    # Newer Gym releases return (observation, info); older ones return only
    # the observation.
    state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
    done = False
    while not done:
        action = env.action_space.sample()  # replace with your action selection method
        step_result = env.step(action)
        if len(step_result) == 5:
            # Newer Gym API: separate terminated / truncated flags.
            next_state, reward, terminated, truncated, info = step_result
            done = terminated or truncated
        else:
            next_state, reward, done, info = step_result
        print("next_state = ", next_state)

        # Prepare the data. Inputs get an explicit batch dimension (shape
        # [1, 1]) so the network output is [1, state_size], which is what
        # both the loss and accuracy() (max over dim 1) require.
        state_tensor = torch.tensor([[state]], dtype=torch.float32)
        action_tensor = torch.tensor([[action]], dtype=torch.float32)
        next_state_tensor = torch.tensor([next_state], dtype=torch.long)

        # Forward pass
        preds = net(state_tensor, action_tensor)

        print("preds = ", preds)
        print("next_state_tensor = ", next_state_tensor)

        # Calculate the loss and the accuracy. Clamp before the log so a
        # probability that underflows to 0 cannot produce -inf / NaN gradients.
        log_probs = torch.log(torch.clamp(preds, min=1e-12))
        loss = loss_fn(log_probs, next_state_tensor)
        acc = accuracy(preds, next_state_tensor)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the loss and the accuracy
        print(f'Loss: {loss.item()}, Accuracy: {acc.item()}')

        # Move to the next state
        state = next_state
Editor is loading...
Leave a Comment