Untitled
unknown
plain_text
a month ago
5.7 kB
2
Indexable
Never
import os
import sys

import gym
import numpy as np
import pygame
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

# Path passed to PPO.save / PPO.load; stable_baselines3 stores it as "<path>.zip".
MODEL_PATH = "robot_model"


class RobotEnv(gym.Env):
    """Grid world in which a robot must collect packages.

    Observation: a ``(grid_size, grid_size, 3)`` float32 array where channel 0
    marks the robot, channel 1 marks the remaining packages and channel 2 is
    always zero.

    Actions: 0 Up, 1 Right, 2 Down, 3 Left. Moves that would leave the grid
    leave the robot where it is.

    Reward: +10 when a package is collected, otherwise -1 per step.

    Uses the classic gym API: ``reset()`` returns only the observation and
    ``step()`` returns ``(obs, reward, done, info)``.
    """

    # (row offset, column offset) for each discrete action.
    _MOVES = {0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1)}

    def __init__(self, grid_size=5, n_packages=3, max_steps=100, cell_size=100):
        """Create the environment.

        Args:
            grid_size: Side length of the square grid.
            n_packages: Number of packages placed at each reset.
            max_steps: Episode step limit (the episode is truncated after it).
            cell_size: Pixel size of one grid cell when rendering.

        Raises:
            ValueError: If the grid cannot hold the robot and every package
                on distinct cells.
        """
        super().__init__()
        if n_packages + 1 > grid_size * grid_size:
            raise ValueError("grid too small for the robot plus all packages")
        self.grid_size = grid_size
        self.n_packages = n_packages
        self.max_steps = max_steps
        self.action_space = spaces.Discrete(4)  # 0: Up, 1: Right, 2: Down, 3: Left
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(grid_size, grid_size, 3), dtype=np.float32
        )
        self.robot_pos = None
        self.packages = None
        self.steps = 0

        # Rendering resources are created lazily on the first render() so
        # that training environments (which never render) do not open a window.
        self.cell_size = cell_size
        self.screen = None
        self.robot_img = None
        self.package_img = None

    def reset(self):
        """Place the robot and packages on distinct random cells.

        Returns:
            The initial observation.
        """
        n_cells = self.grid_size * self.grid_size
        # Sample without replacement so no two objects share a cell. Stacked
        # packages would be indistinguishable in the observation, and a package
        # under the robot's start cell would not be collected by moving onto it.
        cells = np.random.choice(n_cells, size=self.n_packages + 1, replace=False)
        coords = np.stack(np.divmod(cells, self.grid_size), axis=1)  # (k, 2) rows/cols
        self.robot_pos = coords[0].copy()
        self.packages = coords[1:]
        self.steps = 0
        return self._get_obs()

    def step(self, action):
        """Move the robot, collect any package it lands on, and score the step.

        Returns:
            ``(obs, reward, done, info)``. ``info["TimeLimit.truncated"]`` is
            True when the episode ended only because ``max_steps`` was reached.
        """
        self.steps += 1

        # Unknown actions leave the robot in place; clipping keeps it on the grid.
        d_row, d_col = self._MOVES.get(int(action), (0, 0))
        self.robot_pos = np.clip(
            self.robot_pos + (d_row, d_col), 0, self.grid_size - 1
        )

        # Packages occupy distinct cells, so at most one can match.
        hits = np.flatnonzero(np.all(self.packages == self.robot_pos, axis=1))
        collected = hits.size > 0
        if collected:
            self.packages = np.delete(self.packages, hits[0], axis=0)

        # Small penalty for each step without a pickup to encourage efficiency.
        reward = 10 if collected else -1

        terminated = len(self.packages) == 0
        truncated = self.steps >= self.max_steps
        done = terminated or truncated
        # Flag time-limit endings so the learner bootstraps the value of the
        # last state instead of treating the step limit as a true terminal.
        info = {"TimeLimit.truncated": truncated and not terminated}
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        """Encode robot (channel 0) and remaining packages (channel 1) as a grid."""
        obs = np.zeros((self.grid_size, self.grid_size, 3), dtype=np.float32)
        obs[self.robot_pos[0], self.robot_pos[1], 0] = 1
        obs[self.packages[:, 0], self.packages[:, 1], 1] = 1
        return obs

    def _init_display(self):
        """Open the pygame window and load/scale the sprites (first render only)."""
        side = self.grid_size * self.cell_size
        self.screen = pygame.display.set_mode((side, side))
        tile = (self.cell_size, self.cell_size)
        self.robot_img = pygame.transform.scale(
            pygame.image.load("images/robot.png"), tile
        )
        self.package_img = pygame.transform.scale(
            pygame.image.load("images/package.png"), tile
        )

    def render(self):
        """Draw the grid, packages and robot, then flip the display."""
        if self.screen is None:
            self._init_display()

        side = self.grid_size * self.cell_size
        self.screen.fill((255, 255, 255))

        # Grid lines.
        for i in range(self.grid_size):
            offset = i * self.cell_size
            pygame.draw.line(self.screen, (200, 200, 200), (0, offset), (side, offset))
            pygame.draw.line(self.screen, (200, 200, 200), (offset, 0), (offset, side))

        # Sprites are blitted at (x, y) = (column, row) * cell_size.
        for row, col in self.packages:
            self.screen.blit(
                self.package_img, (col * self.cell_size, row * self.cell_size)
            )
        self.screen.blit(
            self.robot_img,
            (self.robot_pos[1] * self.cell_size, self.robot_pos[0] * self.cell_size),
        )

        pygame.display.flip()


def train_model(total_timesteps=1000000, save_path=MODEL_PATH):
    """Train a PPO agent on RobotEnv, save it to ``save_path`` and return it."""
    env = DummyVecEnv([lambda: RobotEnv()])
    model = PPO("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=total_timesteps)
    model.save(save_path)
    return model


def test_model(model, episodes=500, delay_ms=100):
    """Run ``model`` greedily for ``episodes`` episodes while rendering them.

    Prints each episode's total reward. Closing the window stops early.
    pygame is shut down on exit. Errors propagate with their traceback.
    """
    env = RobotEnv()
    try:
        for episode in range(episodes):
            obs = env.reset()
            done = False
            total_reward = 0
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, done, _ = env.step(action)
                total_reward += reward

                env.render()  # render() already flips the display
                pygame.time.wait(delay_ms)  # slow down playback

                # Drain the event queue so the window stays responsive.
                if any(event.type == pygame.QUIT for event in pygame.event.get()):
                    return
            print(f"Episode {episode + 1}: Total Reward = {total_reward}")
    finally:
        pygame.quit()


def main(argv=None):
    """Entry point: train if requested (``--train``) or needed, then test."""
    argv = sys.argv[1:] if argv is None else argv

    try:
        pygame.init()
        print("Pygame initialized successfully.")
        print(pygame.display.Info())  # Print Pygame display info
    except pygame.error as e:
        print(f"Failed to initialize Pygame: {e}")
        sys.exit(1)

    # Training (1M timesteps by default) is expensive, so only do it when
    # explicitly requested or when there is no saved model to load.
    if "--train" in argv or not os.path.exists(MODEL_PATH + ".zip"):
        model = train_model()
    else:
        try:
            model = PPO.load(MODEL_PATH)
            print("Model loaded successfully.")
        except Exception as e:  # top-level boundary: report and exit
            print(f"Error loading model: {e}")
            sys.exit(1)

    test_model(model)
    # Ensure Pygame quits properly (safe even if test_model already quit).
    pygame.quit()


if __name__ == "__main__":
    main()
Leave a Comment