# Paste-site metadata captured with this snippet (not part of the program):
# Untitled
# unknown
# plain_text
# 3 years ago
# 2.9 kB
# 6
# Indexable
import plotly.graph_objs as go
from collections import deque
class Node:
    """A tree node holding a value, its children and optional bookkeeping.

    Attributes:
        val: the value stored in this node.
        children: list of child nodes; a fresh empty list when the argument
            is omitted or falsy.
        parent: the parent node, or ``None``.
        num_descendants: an integer count kept on the node; 0 when the
            argument is omitted or falsy (``bfs()`` may overwrite it).
    """

    def __init__(self, val, children=None, parent=None, num_descendants=None):
        self.val = val
        # ``or`` gives every node its own list rather than a shared default.
        self.children = children or []
        self.parent = parent
        self.num_descendants = num_descendants or 0
def bfs(root):
    """Wire parent links and descendant counts for every node under *root*.

    Walks the tree breadth-first from *root*, setting each child's ``parent``
    to the node whose ``children`` list contains it.  Afterwards every node,
    including *root*, has ``num_descendants`` equal to the number of nodes
    strictly below it (0 for a leaf).  ``root.parent`` is left untouched.

    Args:
        root: the root node of the tree, or ``None`` (then nothing happens).
    """
    if root is None:
        return
    order = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in node.children:
            child.parent = node
            queue.append(child)
    # Reverse BFS order visits every child before its parent, so each
    # child's count is final by the time its parent sums them.
    for node in reversed(order):
        node.num_descendants = sum(1 + child.num_descendants for child in node.children)
def create_node_trace(node, x_pos, y_pos):
    """Return a single-point Scatter trace drawing *node* as a labelled square.

    Args:
        node: the node to draw; ``str(node.val)`` is used as both the on-plot
            label and the hover text.
        x_pos: horizontal position of the marker.
        y_pos: vertical position of the marker.

    Returns:
        A ``plotly.graph_objs.Scatter`` trace that is hidden from the legend.
    """
    # Convert explicitly so the label does not depend on Plotly's coercion
    # of non-string values.
    label = str(node.val)
    return go.Scatter(
        x=[x_pos],
        y=[y_pos],
        text=label,
        mode='markers+text',
        marker=dict(
            symbol='square',
            size=30,
            line=dict(width=2, color='blue'),
            color='white'
        ),
        hoverinfo='text',
        # One trace is created per node; without this each one would add a
        # separate "trace N" entry to the legend.
        showlegend=False,
    )
def create_edge_trace(x1, y1, x2, y2):
    """Return a Scatter trace drawing a straight line from (x1, y1) to (x2, y2).

    Args:
        x1, y1: coordinates of the line's start (the parent node).
        x2, y2: coordinates of the line's end (the child node).

    Returns:
        A ``plotly.graph_objs.Scatter`` line trace that is hidden from the
        legend and ignored by hover, so hovering reports only node labels.
    """
    return go.Scatter(
        x=[x1, x2],
        y=[y1, y2],
        mode='lines',
        line=dict(color='blue', width=2),
        # Axis tick labels are hidden, so raw edge coordinates on hover are
        # just noise.
        hoverinfo='skip',
        # One trace is created per edge; keep them out of the legend.
        showlegend=False,
    )
def _layout_positions(root):
    """Return ``{node: (x, depth)}`` for a tidy, non-overlapping tree layout.

    Leaves are placed on consecutive integer columns in left-to-right
    (depth-first) order, and each internal node is centred over the span of
    its first and last child.  Subtrees therefore occupy disjoint column
    ranges, so no two nodes on the same level share an x coordinate.  An
    explicit stack is used so deep trees do not hit the recursion limit.
    """
    positions = {}
    next_leaf_x = 0
    stack = [(root, 0, False)]
    while stack:
        node, depth, children_placed = stack.pop()
        if not node.children:
            positions[node] = (float(next_leaf_x), depth)
            next_leaf_x += 1
        elif children_placed:
            first_x = positions[node.children[0]][0]
            last_x = positions[node.children[-1]][0]
            positions[node] = ((first_x + last_x) / 2, depth)
        else:
            # Revisit this node once all of its children have been placed.
            stack.append((node, depth, True))
            # Push in reverse so the leftmost child is laid out first.
            for child in reversed(node.children):
                stack.append((child, depth + 1, False))
    return positions


def create_tree_figure(root):
    """Build a Plotly figure drawing the tree rooted at *root*.

    Each node is drawn as a labelled square marker (``create_node_trace``)
    and each parent/child link as a line segment (``create_edge_trace``).
    The root sits at the top at (0, 0), and every level below it is one
    unit lower.  Horizontal positions come from ``_layout_positions``, so
    siblings never overlap.

    Args:
        root: the root ``Node`` of the tree (must not be ``None``).

    Returns:
        A ``plotly.graph_objs.Figure`` whose edge traces precede its node
        traces, so markers are drawn on top of the lines.
    """
    # bfs() is still run for its side effects: it sets ``parent`` and
    # ``num_descendants`` on the nodes, which callers may rely on.  The
    # layout itself does not depend on those values.
    bfs(root)
    positions = _layout_positions(root)
    # Shift horizontally so that the root is at x = 0.
    root_x = positions[root][0]

    node_traces = []
    edge_traces = []
    xs = []
    max_depth = 0
    queue = deque([root])
    while queue:
        node = queue.popleft()
        x, depth = positions[node]
        x -= root_x
        xs.append(x)
        max_depth = max(max_depth, depth)
        node_traces.append(create_node_trace(node, x, -depth))
        for child in node.children:
            child_x, child_depth = positions[child]
            edge_traces.append(
                create_edge_trace(x, -depth, child_x - root_x, -child_depth)
            )
            queue.append(child)

    layout = go.Layout(
        title='Tree Visualization',
        # One trace per node/edge would otherwise flood the legend.
        showlegend=False,
        xaxis=dict(
            range=[min(xs) - 1, max(xs) + 1],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
        ),
        yaxis=dict(
            # Ascending range so y = 0 (the root) is at the top and the
            # deepest level is at the bottom, with one unit of padding.
            range=[-max_depth - 1, 1],
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor='x',
            scaleratio=1
        ),
        hovermode='closest'
    )
    return go.Figure(data=edge_traces + node_traces, layout=layout)
# Example tree: a root with three children, each having children of its own.
root = Node(1, children=[
    Node(2, children=[Node(5), Node(6)]),
    Node(3, children=[Node(7), Node(8)]),
    Node(4, children=[Node(9), Node(10), Node(11)]),
])

# Visualize the tree.
fig = create_tree_figure(root)
fig.show()
# Editor is loading... (paste-site UI text, not part of the program)