Untitled
unknown
plain_text
9 months ago
3.6 kB
10
Indexable
import numpy as np
import matplotlib.pyplot as plt
import random
class Node:
    """A vertex of the RRT search tree: a planar pose plus a back-pointer."""

    def __init__(self, x, y, theta=0):
        # Pose of the robot at this vertex (position and heading).
        self.x, self.y, self.theta = x, y, theta
        # Edge toward the root of the tree; None until the node is attached
        # (and always None for the root itself).
        self.parent = None
class RRT:
    """Rapidly-exploring Random Tree planner for a differential-drive robot.

    The tree grows inside the obstacle-free rectangle
    ``[0, map_size[0]] x [0, map_size[1]]``. Each extension integrates the
    differential-drive kinematics for one time step of length ``step_size``
    using one of a few wheel-speed primitives. It keeps whichever primitive's
    end pose lands closest to the random sample.
    """

    def __init__(self, start, goal, map_size, step_size, max_iter, wheel_base,
                 goal_tolerance=None, goal_sample_rate=0.1):
        """Create a planner.

        Args:
            start: ``(x, y[, theta])`` of the root pose.
            goal: ``(x, y[, theta])`` of the goal pose. Only its position is
                used by the goal test.
            map_size: ``(width, height)`` of the map. Samples are drawn from,
                and tree nodes are kept within, ``[0, width] x [0, height]``.
            step_size: Integration time step used for every tree extension.
            max_iter: Maximum number of sampling iterations in :meth:`plan`.
            wheel_base: Distance between the left and right wheels.
            goal_tolerance: Distance to the goal at which a node counts as
                having reached it. ``None`` (default) falls back to
                ``step_size``, which matches the original behaviour.
            goal_sample_rate: Probability of sampling the goal itself instead
                of a uniformly random point.
        """
        self.start = Node(*start)
        self.goal = Node(*goal)
        self.map_size = map_size
        self.step_size = step_size
        self.max_iter = max_iter
        self.wheel_base = wheel_base
        self.goal_tolerance = goal_tolerance
        self.goal_sample_rate = goal_sample_rate
        self.nodes = [self.start]

    def distance(self, node1, node2):
        """Return the Euclidean distance between two nodes' positions (heading ignored)."""
        return np.hypot(node1.x - node2.x, node1.y - node2.y)

    def get_random_node(self):
        """Return a sampling target.

        With probability ``goal_sample_rate`` this is the goal node itself;
        otherwise it is a new node placed uniformly at random inside the map.
        """
        if random.random() < self.goal_sample_rate:
            return self.goal
        return Node(random.uniform(0, self.map_size[0]),
                    random.uniform(0, self.map_size[1]))

    def get_nearest_node(self, random_node):
        """Return the tree node whose position is closest to ``random_node``."""
        return min(self.nodes, key=lambda node: self.distance(node, random_node))

    def _propagate(self, from_node, omega_left, omega_right, wheel_radius):
        """Integrate one ``step_size`` of differential-drive motion from ``from_node``."""
        v_left = omega_left * wheel_radius
        v_right = omega_right * wheel_radius
        v = (v_left + v_right) / 2  # forward speed of the axle midpoint
        omega = (v_right - v_left) / self.wheel_base  # yaw rate
        theta = from_node.theta + omega * self.step_size
        # Semi-implicit Euler: position advances along the *updated* heading.
        x = from_node.x + v * np.cos(theta) * self.step_size
        y = from_node.y + v * np.sin(theta) * self.step_size
        node = Node(x, y, theta)
        node.parent = from_node
        return node

    def steer(self, from_node, to_node, omega_left, omega_right, wheel_radius):
        """Extend ``from_node`` by one time step toward ``to_node``.

        Three wheel-speed primitives with the same forward speed are tried:
        the given pair (turns one way), the mirrored pair (turns the other
        way), and their mean on both wheels (drives straight). The resulting
        pose closest to ``to_node`` wins, with ties going to the earlier
        primitive. Choosing among primitives is what makes the tree actually
        grow toward samples. Always applying a single fixed pair would only
        ever trace one circular arc.

        Returns:
            A new :class:`Node` whose ``parent`` is ``from_node``. It is not
            added to the tree.
        """
        omega_mean = (omega_left + omega_right) / 2
        primitives = (
            (omega_left, omega_right),
            (omega_right, omega_left),
            (omega_mean, omega_mean),
        )
        candidates = [
            self._propagate(from_node, w_left, w_right, wheel_radius)
            for w_left, w_right in primitives
        ]
        return min(candidates, key=lambda node: self.distance(node, to_node))

    def is_goal_reached(self, node):
        """Return True if ``node`` lies within the goal tolerance of the goal position."""
        tolerance = self.step_size if self.goal_tolerance is None else self.goal_tolerance
        return self.distance(node, self.goal) <= tolerance

    def _in_bounds(self, node):
        """Return True if ``node`` lies inside the map rectangle."""
        return 0 <= node.x <= self.map_size[0] and 0 <= node.y <= self.map_size[1]

    def get_path(self):
        """Return ``(x, y)`` waypoints from the root to the most recently added node."""
        path = []
        node = self.nodes[-1]
        while node is not None:
            path.append((node.x, node.y))
            node = node.parent
        return path[::-1]

    def plan(self, omega_left, omega_right, wheel_radius):
        """Grow the tree until a node reaches the goal or ``max_iter`` runs out.

        Args:
            omega_left: Left-wheel angular speed of the base primitive.
            omega_right: Right-wheel angular speed of the base primitive.
            wheel_radius: Wheel radius. It converts wheel angular speed to
                linear speed.

        Returns:
            A list of ``(x, y)`` waypoints from start to goal, or ``None`` if
            the goal was not reached within ``max_iter`` iterations.
        """
        for _ in range(self.max_iter):
            rand_node = self.get_random_node()
            nearest_node = self.get_nearest_node(rand_node)
            new_node = self.steer(nearest_node, rand_node, omega_left, omega_right, wheel_radius)
            if not self._in_bounds(new_node):
                continue  # keep the tree inside the region samples are drawn from
            self.nodes.append(new_node)
            if self.is_goal_reached(new_node):
                print("Goal reached!")
                # Attach a fresh copy of the goal instead of re-parenting
                # self.goal. If self.goal itself were a tree node, a later
                # plan() call could extend from it and then re-parent it,
                # creating a parent cycle that makes get_path() loop forever.
                goal_node = Node(self.goal.x, self.goal.y, self.goal.theta)
                goal_node.parent = new_node
                self.nodes.append(goal_node)
                return self.get_path()
        print("Max iterations reached, goal not found.")
        return None

    def draw_map(self, path=None):
        """Plot tree edges, the optional ``path`` and start/goal markers, then show the figure."""
        plt.figure(figsize=(8, 8))
        plt.xlim(0, self.map_size[0])
        plt.ylim(0, self.map_size[1])
        for node in self.nodes:
            if node.parent is not None:
                plt.plot([node.x, node.parent.x], [node.y, node.parent.y], "b-")
        if path:
            plt.plot(*zip(*path), "r-", linewidth=2, label="Path")
        plt.scatter(self.start.x, self.start.y, color="green", s=100, label="Start")
        plt.scatter(self.goal.x, self.goal.y, color="red", s=100, label="Goal")
        plt.legend()
        plt.show()
if __name__ == "__main__":
    # Demo scenario: plan from one corner region of an empty 100 x 100 map
    # toward the other, then plot the tree and any path found.
    start, goal = (10, 10, 0), (90, 90, 0)  # (x, y, theta)
    map_size = (100, 100)
    step_size = 1  # integration time step per tree extension
    max_iter = 1000
    wheel_base = 1  # distance between the left and right wheels
    wheel_radius = 0.1
    # Example wheel angular velocities.
    omega_left, omega_right = np.pi / 3, np.pi / 2

    rrt = RRT(
        start=start,
        goal=goal,
        map_size=map_size,
        step_size=step_size,
        max_iter=max_iter,
        wheel_base=wheel_base,
    )
    path = rrt.plan(
        omega_left=omega_left,
        omega_right=omega_right,
        wheel_radius=wheel_radius,
    )
    rrt.draw_map(path)
Editor is loading...
Leave a Comment