# Source: untitled paste (language: unknown, format: plain_text, 5.7 kB,
# posted "a year ago"; listing labels: "10", "Indexable").
import argparse
import sys
import traceback
from pathlib import Path

import gym
from gym import spaces
import numpy as np
import pygame
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
class RobotEnv(gym.Env):
    """Grid world in which a robot must collect packages.

    The robot moves on a ``grid_size`` x ``grid_size`` board. It receives
    +10 for stepping onto a package and -1 for every other step. An episode
    ends when every package has been collected or after ``max_steps`` steps.

    Observations are float32 arrays of shape ``(grid_size, grid_size, 3)``:
    channel 0 marks the robot, channel 1 marks the remaining packages, and
    channel 2 is currently unused (always zero).

    Uses the classic 4-tuple ``gym`` step API ``(obs, reward, done, info)``.

    Args:
        grid_size: Side length of the square grid.
        num_packages: Number of packages placed at the start of each episode.
        max_steps: Step budget per episode.
        cell_size: Pixel size of one grid cell when rendering.

    Raises:
        ValueError: If the grid has too few cells for the robot plus every
            package.
    """

    # Action id -> (row delta, column delta). 0: Up, 1: Right, 2: Down, 3: Left.
    _MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

    def __init__(self, grid_size=5, num_packages=3, max_steps=100, cell_size=100):
        super().__init__()
        if num_packages + 1 > grid_size * grid_size:
            raise ValueError(
                "grid too small: need a distinct cell for the robot and each package"
            )
        self.grid_size = grid_size
        self.num_packages = num_packages
        self.action_space = spaces.Discrete(4)  # 0: Up, 1: Right, 2: Down, 3: Left
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(self.grid_size, self.grid_size, 3), dtype=np.float32
        )
        self.robot_pos = None
        self.packages = None
        self.steps = 0
        self.max_steps = max_steps

        # Pygame resources are created lazily by render(). This lets the env be
        # built and trained headless: no window and no image files are needed,
        # and no unpumped window sits "not responding" during training.
        self.cell_size = cell_size
        self.screen = None
        self.robot_img = None
        self.package_img = None

    def reset(self):
        """Start a new episode and return the initial observation.

        Robot and packages are placed on distinct random cells. The
        observation can mark only one package per cell, so stacked packages
        would be invisible to the agent. A package under the robot's start
        cell would not count as collected until the robot left and came back.
        """
        n_cells = self.grid_size * self.grid_size
        cells = np.random.choice(n_cells, size=self.num_packages + 1, replace=False)
        coords = np.column_stack(
            np.unravel_index(cells, (self.grid_size, self.grid_size))
        )
        self.robot_pos = coords[0].copy()
        self.packages = coords[1:].copy()
        self.steps = 0
        return self._get_obs()

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, done, info)``.

        A move that would leave the grid keeps the robot in place. An action
        id outside ``_MOVES`` also leaves the robot in place.
        """
        self.steps += 1

        # np.asarray(...).item() accepts Python ints, NumPy scalars and the
        # 0-d / 1-element arrays returned by ``model.predict``.
        move = self._MOVES.get(int(np.asarray(action).item()))
        if move is not None:
            self.robot_pos = np.clip(self.robot_pos + move, 0, self.grid_size - 1)

        # Collect any package on the robot's cell. reset() places packages on
        # distinct cells, so at most one matches.
        hit = np.all(self.packages == self.robot_pos, axis=1)
        collected = bool(hit.any())
        if collected:
            self.packages = self.packages[~hit]

        # Small per-step penalty encourages short routes.
        reward = 10 if collected else -1
        done = len(self.packages) == 0 or self.steps >= self.max_steps
        return self._get_obs(), reward, done, {}

    def _get_obs(self):
        """Encode robot (channel 0) and package (channel 1) positions."""
        obs = np.zeros((self.grid_size, self.grid_size, 3), dtype=np.float32)
        obs[self.robot_pos[0], self.robot_pos[1], 0] = 1
        # Fancy indexing marks every remaining package at once. It is also a
        # no-op when the package array is empty.
        obs[self.packages[:, 0], self.packages[:, 1], 1] = 1
        return obs

    def _ensure_display(self):
        """Create the pygame window and load the sprites on first use."""
        if self.screen is not None:
            return
        side = self.grid_size * self.cell_size
        self.screen = pygame.display.set_mode((side, side))
        sprite_size = (self.cell_size, self.cell_size)
        self.robot_img = pygame.transform.scale(
            pygame.image.load("images/robot.png"), sprite_size
        )
        self.package_img = pygame.transform.scale(
            pygame.image.load("images/package.png"), sprite_size
        )

    def render(self, mode="human"):
        """Draw the current grid to the pygame window.

        ``mode`` is accepted for compatibility with the classic ``gym``
        ``Env.render(mode=...)`` signature. Only on-screen drawing is
        implemented.
        """
        self._ensure_display()
        side = self.grid_size * self.cell_size
        line_color = (200, 200, 200)

        self.screen.fill((255, 255, 255))
        for i in range(self.grid_size):
            offset = i * self.cell_size
            pygame.draw.line(self.screen, line_color, (0, offset), (side, offset))
            pygame.draw.line(self.screen, line_color, (offset, 0), (offset, side))

        for row, col in self.packages:
            self.screen.blit(
                self.package_img,
                (int(col) * self.cell_size, int(row) * self.cell_size),
            )

        row, col = self.robot_pos
        self.screen.blit(
            self.robot_img, (int(col) * self.cell_size, int(row) * self.cell_size)
        )
        pygame.display.flip()

    def close(self):
        """Close the pygame window if render() ever opened one."""
        if self.screen is not None:
            pygame.display.quit()
            self.screen = None
def train_model(total_timesteps=1000000, save_path="robot_model"):
    """Train a PPO agent on ``RobotEnv``, save it, and return it.

    Args:
        total_timesteps: Number of environment steps to train for.
        save_path: Destination passed to ``model.save``. stable-baselines3
            stores the model as a ``.zip`` archive (``robot_model.zip`` for
            the default).

    Returns:
        The trained ``PPO`` model.
    """
    # DummyVecEnv takes a list of zero-argument env factories; the class itself is one.
    env = DummyVecEnv([RobotEnv])
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=total_timesteps)
    model.save(save_path)
    return model
def test_model(model, episodes=500, delay_ms=100):
    """Play ``model`` greedily in a rendered ``RobotEnv`` and print rewards.

    Args:
        model: Trained policy exposing ``predict(obs, deterministic=True)``
            (e.g. a stable-baselines3 ``PPO``).
        episodes: Number of episodes to play.
        delay_ms: Pause after each rendered frame so play is watchable.

    Closing the pygame window stops the run early. Errors are reported
    together with their traceback instead of being silently summarized.
    Pygame is always shut down, exactly once, in ``finally``.
    """
    env = RobotEnv()
    pygame.time.wait(1000)  # brief pause before playback starts
    try:
        for episode in range(episodes):
            obs = env.reset()
            done = False
            total_reward = 0
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, done, _ = env.step(action)
                total_reward += reward
                env.render()  # render() already flips the display
                pygame.time.wait(delay_ms)
                # Drain the event queue every frame so the window stays responsive.
                if any(event.type == pygame.QUIT for event in pygame.event.get()):
                    return  # `finally` shuts pygame down
            print(f"Episode {episode + 1}: Total Reward = {total_reward}")
    except Exception as e:
        print(f"An error occurred during testing: {e}")
        traceback.print_exc()
    finally:
        pygame.quit()
if __name__ == "__main__":
    # Entry point: optionally train, then load the saved model and watch it play.
    parser = argparse.ArgumentParser(
        description="Train and/or watch a PPO package-collecting robot."
    )
    parser.add_argument(
        "--train",
        action="store_true",
        help="train a new model even if a saved one already exists",
    )
    parser.add_argument(
        "--timesteps", type=int, default=1000000, help="training timesteps"
    )
    parser.add_argument(
        "--episodes", type=int, default=500, help="number of episodes to watch"
    )
    args = parser.parse_args()

    try:
        pygame.init()
        print("Pygame initialized successfully.")
        print(pygame.display.Info())  # Print Pygame display info
    except pygame.error as e:
        print(f"Failed to initialize Pygame: {e}")
        sys.exit(1)

    # train_model() saves to "robot_model" by default; stable-baselines3 adds
    # the ".zip" suffix. Train only on request or when no saved model exists,
    # instead of retraining for a million steps on every run.
    model_path = "robot_model"
    if args.train or not Path(f"{model_path}.zip").exists():
        train_model(total_timesteps=args.timesteps)

    try:
        model = PPO.load(model_path)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")
        sys.exit(1)

    test_model(model, episodes=args.episodes)
    # Safe even if test_model already shut pygame down.
    pygame.quit()
# (paste-site page residue: "Editor is loading...", "Leave a Comment")