Untitled

mail@pastecode.io avatar
unknown
plain_text
a year ago
2.9 kB
3
Indexable
Never
import plotly.graph_objs as go
from collections import deque

class Node:
    """A tree node holding a value, its child nodes and a parent link.

    Attributes:
        val: The payload displayed for this node.
        children: List of child ``Node`` objects. A fresh empty list is used
            when ``children`` is omitted or falsy (e.g. ``None`` or ``[]``).
        parent: The parent ``Node``, or ``None`` when not given.
        num_descendants: Integer bookkeeping count; ``0`` when omitted or
            falsy.
    """

    def __init__(self, val, children=None, parent=None, num_descendants=None):
        self.val = val
        self.parent = parent
        # `or` substitutes the default for None *and* any other falsy input,
        # so every node without children owns its own fresh list.
        self.children = children or []
        self.num_descendants = num_descendants or 0

def bfs(root):
    """Link every node to its parent and count the nodes below it.

    Walks the tree rooted at *root* breadth-first and sets ``child.parent``
    on every child. It then sets each node's ``num_descendants`` to the size
    of its subtree excluding the node itself (0 for a leaf). ``root.parent``
    is left untouched.

    Args:
        root: Root node; must expose ``children`` (a list of nodes) and
            writable ``parent`` and ``num_descendants`` attributes.
    """
    order = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in node.children:
            child.parent = node
            queue.append(child)

    # Reverse BFS order visits every child before its parent, so each
    # child's subtree size is already final when the parent sums them.
    for node in reversed(order):
        node.num_descendants = sum(1 + child.num_descendants for child in node.children)

def create_node_trace(node, x_pos, y_pos):
    """Build a Plotly scatter trace that draws one tree node as a labelled square.

    Args:
        node: Tree node; ``str(node.val)`` is used as both label and hover text.
        x_pos: Horizontal position of the node.
        y_pos: Vertical position of the node.

    Returns:
        go.Scatter: A single-point trace with a white square marker, a blue
        outline, and the node's value printed on it.
    """
    return go.Scatter(
        x=[x_pos],
        y=[y_pos],
        # One label per point, given explicitly as a string.
        text=[str(node.val)],
        mode='markers+text',
        marker=dict(
            symbol='square',
            size=30,
            line=dict(width=2, color='blue'),
            color='white',
        ),
        hoverinfo='text',
        # Every node is its own trace; without this, each one would add a
        # "trace N" entry to the figure legend.
        showlegend=False,
    )

def create_edge_trace(x1, y1, x2, y2):
    """Build a Plotly line trace connecting two points (a parent-child edge).

    Args:
        x1, y1: Coordinates of the first endpoint.
        x2, y2: Coordinates of the second endpoint.

    Returns:
        go.Scatter: A two-point blue line trace.
    """
    return go.Scatter(
        x=[x1, x2],
        y=[y1, y2],
        mode='lines',
        line=dict(color='blue', width=2),
        # Edge coordinates are layout artefacts, so hovering shows nothing.
        hoverinfo='skip',
        # Every edge is its own trace; keep them out of the legend.
        showlegend=False,
    )

def _layout_positions(root):
    """Return ``{id(node): (x, depth)}`` for every node in the tree at *root*.

    Leaves get consecutive integer x positions in left-to-right order, and
    each internal node is centered between its first and last child, so
    sibling subtrees never overlap. All x values are shifted so the root is
    at x = 0. The root has depth 0, and depth grows by 1 per level.
    """
    positions = {}
    next_leaf_x = 0

    def place(node, depth):
        nonlocal next_leaf_x
        if node.children:
            child_xs = [place(child, depth + 1) for child in node.children]
            x = (child_xs[0] + child_xs[-1]) / 2
        else:
            x = next_leaf_x
            next_leaf_x += 1
        positions[id(node)] = (x, depth)
        return x

    root_x = place(root, 0)
    return {key: (x - root_x, depth) for key, (x, depth) in positions.items()}


def create_tree_figure(root):
    """Render the tree rooted at *root* as a Plotly figure.

    Calls ``bfs(root)`` first (which sets parent links and ``num_descendants``
    on the nodes). It then draws every node as a labelled square and every
    parent-child link as a line. The root is drawn at the top (y = 0), and
    each level sits one unit below the previous one.

    Args:
        root: Root ``Node`` of the tree.

    Returns:
        go.Figure: The assembled figure; call ``.show()`` to display it.
    """
    bfs(root)
    positions = _layout_positions(root)

    node_traces = []
    edge_traces = []
    # Queue entries carry the parent explicitly, so edge drawing does not
    # depend on the parent links stored on the nodes.
    node_queue = deque([(root, None)])
    while node_queue:
        node, parent = node_queue.popleft()
        x, depth = positions[id(node)]
        node_traces.append(create_node_trace(node, x, -depth))

        if parent is not None:
            parent_x, parent_depth = positions[id(parent)]
            edge_traces.append(create_edge_trace(parent_x, -parent_depth, x, -depth))

        node_queue.extend((child, node) for child in node.children)

    max_depth = max(depth for _, depth in positions.values())
    # Symmetric around the root, with one unit of padding on each side.
    x_extent = max(abs(x) for x, _ in positions.values()) + 1
    x_axis_range = [-x_extent, x_extent]
    # Increasing range: root (y = 0) at the top, deepest level at the bottom,
    # with one unit of padding at each end so markers are not clipped.
    y_axis_range = [-max_depth - 1, 1]

    layout = go.Layout(
        title='Tree Visualization',
        showlegend=False,
        xaxis=dict(
            range=x_axis_range,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
        ),
        yaxis=dict(
            range=y_axis_range,
            showgrid=False,
            zeroline=False,
            showticklabels=False,
            scaleanchor='x',
            scaleratio=1,
        ),
        hovermode='closest',
    )

    # Edges first so the node markers are drawn on top of them.
    fig = go.Figure(data=edge_traces + node_traces, layout=layout)
    return fig

# Example tree: a root with three children, each with its own children.
root = Node(1, children=[
    Node(2, children=[Node(5), Node(6)]),
    Node(3, children=[Node(7), Node(8)]),
    Node(4, children=[Node(9), Node(10), Node(11)]),
])

# Visualize the tree.
fig = create_tree_figure(root)
fig.show()