Untitled

mail@pastecode.io avatar
unknown
plain_text
19 days ago
20 kB
2
Indexable
Never
import base64
import hashlib
import json
import os
import random
import tempfile
from typing import Dict, List, Tuple

import dash
import plotly.graph_objects as go
from dash import dcc, html
from dash.dependencies import Input, Output, State

from flash_core.video_evaluation.htx.evaluation_pipeline import (
    HTXEvaluationPipeline,
)

# Constants

# Category labels drawn on the action-of-interest charts, in display order.
# The chart helpers look these labels up in each vendor's metric dicts.
CATEGORIES_ACTION_OF_INTEREST = (
    "falling fighting self-harm none all_actions".split()
)

# Trace colours, assigned to vendors in the order their data appears.
COLORS = (
    "#FF0000 #00FF00 #0000FF #FFA500 #800080 "
    "#00FFFF #FFD700 #008080 #800000 #008000"
).split()

# Single evaluation pipeline created once at import time and shared by the
# metric helper and the upload callback below.
htx_pipeline = HTXEvaluationPipeline()


def get_metrics_result_from_vendor_prediction(
    results_prediction_path: str,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
    """Evaluate one vendor prediction file and return its metrics scaled by 100.

    The file is evaluated twice with the shared ``htx_pipeline``: once for the
    ``"action_of_interest"`` task and once for ``"action_recognition"``. Only
    the ``["all_video"]["result"]`` section of each evaluation is used.

    Args:
        results_prediction_path: Path of the vendor's prediction file on disk.

    Returns:
        A ``(detection_rate, false_positive_rate, accuracy)`` tuple. Each item
        maps an action name to the pipeline's metric multiplied by 100
        (presumably turning a 0-1 fraction into the 0-100 scale the charts
        use — confirm against the pipeline's output).
    """
    action_of_interest_results = htx_pipeline.evaluate(
        results_prediction_path, task="action_of_interest"
    )["all_video"]["result"]
    action_recognition_results = htx_pipeline.evaluate(
        results_prediction_path, task="action_recognition"
    )["all_video"]["result"]

    detection_rate = {
        action: metrics["detection_rate"] * 100
        for action, metrics in action_of_interest_results.items()
    }
    false_positive_rate = {
        action: metrics["false_positive_rate"] * 100
        for action, metrics in action_of_interest_results.items()
    }
    accuracy = {
        action: metrics["accuracy"] * 100
        for action, metrics in action_recognition_results.items()
    }

    return detection_rate, false_positive_rate, accuracy


# Chart creation functions
def create_radar_chart(data, data_type, categories):
    """Build a radar (polar) chart comparing vendors on one metric.

    Args:
        data: List of vendor records. Each record has a ``"name"`` key and a
            ``data_type`` key that maps category -> metric value, plotted on
            a 0-100 radial axis.
        data_type: Metric to plot. ``"accuracy"`` and ``"detection_rate"``
            get their own titles; any other value is titled
            "False Positive Rate" and annotated "Lower is better".
        categories: Category labels placed around the angular axis.

    Returns:
        A ``plotly.graph_objects.Figure``.

    A category missing from a vendor's metrics is plotted as 0, matching
    ``create_bar_chart``. Colours cycle through ``COLORS``, so every vendor is
    drawn even when there are more vendors than colours.
    """
    categories = list(categories)
    fig = go.Figure()

    # Dashed spokes from the centre out to each category vertex.
    for category in categories:
        fig.add_trace(
            go.Scatterpolar(
                r=[0, 100],
                theta=[category, category],
                mode="lines",
                line=dict(color="gray", width=1, dash="dash"),
                showlegend=False,
                hoverinfo="skip",
            )
        )

    # Repeat the first vertex at the end so each vendor polygon closes.
    # Slicing instead of indexing [0] keeps this safe for an empty list.
    closed_theta = categories + categories[:1]
    metric_label = data_type.replace("_", " ").title()
    for index, vendor_data in enumerate(data):
        metrics = vendor_data[data_type]
        values = [metrics.get(cat, 0) for cat in categories]
        fig.add_trace(
            go.Scatterpolar(
                r=values + values[:1],
                theta=closed_theta,
                name=vendor_data["name"],
                fill="toself",
                line=dict(color=COLORS[index % len(COLORS)], width=3),
                opacity=0.7,
                hovertemplate=(
                    "<b>%{theta}</b><br>"
                    f"{metric_label}: %{{r}}<br>"
                    "<extra></extra>"
                ),
            )
        )

    title = {
        "accuracy": "Accuracy",
        "detection_rate": "Detection Rate",
    }.get(data_type, "False Positive Rate")
    annotation = (
        "Higher is better"
        if data_type in ["detection_rate", "accuracy"]
        else "Lower is better"
    )

    fig.update_polars(angularaxis_type="category")
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=14),
                tickangle=30,
                tickvals=[20, 40, 60, 80, 100],
                ticktext=["20", "40", "60", "80", "100"],
            ),
            angularaxis=dict(
                tickfont=dict(size=22, color="red"),
                rotation=45,
                direction="clockwise",
            ),
            bgcolor="rgba(240, 240, 240, 0.8)",
            gridshape="circular",
        ),
        showlegend=True,
        legend=dict(
            font=dict(size=22),
            orientation="h",
            yanchor="bottom",
            y=-0.1,
            xanchor="center",
            x=0.5,
        ),
        title=dict(
            text=f"{title} - Performance Comparison",
            font=dict(size=30, family="Arial", color="black"),
            x=0.5,
            y=0.96,
        ),
        height=1000,
        margin=dict(t=100, b=80, l=80, r=80),
    )

    fig.add_annotation(
        x=0.7,
        y=-0.03,
        xref="paper",
        yref="paper",
        text=annotation,
        showarrow=False,
        font=dict(size=18),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
        borderpad=4,
        opacity=0.8,
    )

    return fig


def create_bar_chart(data, data_type, categories):
    """Build a grouped bar chart comparing vendors on one metric.

    Args:
        data: List of vendor records. Each record has a ``"name"`` key and a
            ``data_type`` key that maps category -> metric value, plotted on
            a 0-100 y-axis.
        data_type: Metric to plot. ``"accuracy"`` and ``"detection_rate"``
            get their own titles; any other value is titled
            "False Positive Rate" and annotated "Lower is better".
        categories: Category labels along the x-axis.

    Returns:
        A ``plotly.graph_objects.Figure``. A category missing from a vendor's
        metrics is plotted as 0, and colours cycle through ``COLORS``.
    """
    fig = go.Figure()

    for i, vendor_data in enumerate(data):
        values = [vendor_data[data_type].get(cat, 0) for cat in categories]
        fig.add_trace(
            go.Bar(
                x=categories,
                y=values,
                name=vendor_data["name"],
                marker_color=COLORS[i % len(COLORS)],
                opacity=0.7,
            )
        )

    title = {
        "accuracy": "Accuracy",
        "detection_rate": "Detection Rate",
    }.get(data_type, "False Positive Rate")
    annotation = (
        "Higher is better"
        if data_type in ["detection_rate", "accuracy"]
        else "Lower is better"
    )

    fig.update_layout(
        title=dict(
            text=f"{title} - Performance Comparison",
            font=dict(size=30, family="Arial", color="black"),
            x=0.5,
            y=0.96,
        ),
        # Axis title fonts go under title.font; the old axis ``titlefont``
        # property is deprecated and was removed in Plotly 6.
        xaxis=dict(
            title=dict(text="Categories", font=dict(size=24)),
            tickfont=dict(size=14),
        ),
        yaxis=dict(
            title=dict(text=title, font=dict(size=16)),
            tickfont=dict(size=14),
            range=[0, 100],
            tickvals=[20, 40, 60, 80, 100],
            ticktext=["20", "40", "60", "80", "100"],
        ),
        barmode="group",
        bargap=0.25,
        bargroupgap=0.05,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.2,
            xanchor="center",
            x=0.5,
            font=dict(size=22),
        ),
        annotations=[
            dict(
                x=0.85,
                y=-0.1,
                xref="paper",
                yref="paper",
                text=annotation,
                showarrow=False,
                bgcolor="white",
                bordercolor="black",
                borderwidth=1,
                borderpad=4,
                font=dict(size=18),
            )
        ],
        height=1000,
        margin=dict(t=120, b=80, l=0, r=80),
        plot_bgcolor="rgba(240, 240, 240, 0.8)",
    )

    # Bold red category labels plus dashed grid lines on the x-axis.
    fig.update_xaxes(
        tickfont=dict(color="black"),
        tickvals=categories,
        ticktext=[
            f'<span style="color: red; font-weight: bold;">{cat}</span>'
            for cat in categories
        ],
        gridcolor="lightgray",
        gridwidth=1,
        griddash="dash",
    )
    fig.update_yaxes(gridcolor="lightgray", gridwidth=1, griddash="dash")

    return fig


# App layout components
def create_sidebar():
    """Return the fixed left sidebar.

    The sidebar contains the multi-file upload area (``upload-data``), the
    "Ingested Files" list (``file-list``) and the task selector dropdown
    (``task-selector``, defaulting to ``"action_of_interest"``).
    """
    return html.Div(
        [
            dcc.Upload(
                id="upload-data",
                children=html.Div(
                    [
                        html.P(
                            "Upload Files \n or \n Drag and Drop\nhere",
                            style={
                                "whiteSpace": "pre-line",
                                "color": "#007bff",
                                "fontWeight": "bold",
                            },
                        )
                    ]
                ),
                style={
                    "width": "88%",
                    "height": "150px",
                    "lineHeight": "30px",
                    "borderWidth": "3px",
                    "borderStyle": "groove",
                    "borderColor": "#bf1f24",
                    "borderRadius": "20px",
                    "textAlign": "center",
                    "backgroundColor": "#ffffff",
                    "margin": "10px",
                },
                multiple=True,
            ),
            html.H2("Ingested Files"),
            html.Ul(
                id="file-list",
                style={
                    "listStyleType": "none",
                    "fontSize": "18px",
                    "padding": "0",
                    "backgroundColor": "#e9ecef",
                    "borderRadius": "10px",
                    "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.1)",
                    "lineHeight": "30px",
                },
            ),
            html.H2("Select Task"),
            dcc.Dropdown(
                id="task-selector",
                options=[
                    {
                        "label": "Action of Interest",
                        "value": "action_of_interest",
                    },
                    {
                        "label": "Action Recognition",
                        "value": "action_recognition",
                    },
                ],
                value="action_of_interest",
                style={"marginBottom": "20px"},
            ),
        ],
        # Dash style keys are camelCase (React inline styles); the hyphenated
        # "background-color" triggered React's unsupported-style warning.
        style={
            "width": "9%",
            "backgroundColor": "#f2f2f2",
            "padding": "2px",
            "position": "fixed",
            "height": "100%",
            "overflow": "auto",
        },
    )


def create_main_content():
    """Return the main content column.

    From top to bottom it holds the page heading (``main-title``), the chart
    type picker (``chart-type-dropdown``), two hidden holders for the
    JSON-serialised vendor metrics, and the ``chart-container`` the charts
    are rendered into.
    """
    heading = html.H1(
        id="main-title",
        style={
            "textAlign": "center",
            "fontSize": "40px",
            "marginBottom": "20px",
            "marginLeft": "10%",
        },
    )

    chart_type_picker = html.Div(
        dcc.Dropdown(
            id="chart-type-dropdown",
            options=[
                {"label": "Radar Chart", "value": "radar"},
                {"label": "Bar Chart", "value": "bar"},
            ],
            value="radar",
            style={"width": "50%", "marginLeft": "28%"},
        )
    )

    # Invisible Divs whose children carry the JSON payloads written by the
    # upload callback and read by the chart callback.
    hidden_holders = [
        html.Div(id=holder_id, style={"display": "none"})
        for holder_id in ("action-of-interest-data", "action-recognition-data")
    ]

    chart_area = html.Div(
        id="chart-container",
        style={
            "marginLeft": "10%",
            "display": "flex",
            "flexWrap": "wrap",
        },
    )

    return html.Div([heading, chart_type_picker, *hidden_holders, chart_area])


# Initialize the Dash app and its layout. The two in-memory stores keep the
# list of uploaded filenames and the per-file status ("processing",
# "success" or "failed"); the sidebar and main content follow them.
app = dash.Dash(__name__)

app.layout = html.Div(
    [
        *(
            dcc.Store(id=store_id, storage_type="memory")
            for store_id in ("uploaded-filenames", "file-status")
        ),
        create_sidebar(),
        create_main_content(),
    ]
)


# Text colour for each known per-file status; any other status renders red.
_FILE_STATUS_COLORS = {"processing": "orange", "success": "green"}


def _render_file_items(filenames, statuses):
    """Return one ``html.Li`` per filename, coloured by its status.

    A filename with no status entry yet is shown as "processing" instead of
    raising KeyError.
    """
    return [
        html.Li(
            filename,
            style={
                "color": _FILE_STATUS_COLORS.get(
                    statuses.get(filename, "processing"), "red"
                )
            },
        )
        for filename in filenames
    ]


@app.callback(
    Output("file-list", "children"),
    Output("uploaded-filenames", "data"),
    Output("file-status", "data"),
    [Input("upload-data", "filename")],
    [State("uploaded-filenames", "data"), State("file-status", "data")],
)
def update_file_list(new_filenames, existing_filenames, existing_status):
    """Merge newly uploaded filenames into the "Ingested Files" list.

    Args:
        new_filenames: Filenames from the latest upload (list, or None).
        existing_filenames: Filenames already stored, or None.
        existing_status: Mapping filename -> status already stored, or None.

    Returns:
        ``(list_items, filenames, statuses)`` for the ``file-list`` children
        and the two stores. New files start as ``"processing"``; files seen
        before keep their current status.

    NOTE(review): ``file-status`` is only read as State here, so statuses
    written later by ``update_output`` are not re-rendered until the next
    upload — confirm whether that is intended.
    """
    if existing_filenames is None:
        existing_filenames = []
    if existing_status is None:
        existing_status = {}

    if new_filenames:
        # dict.fromkeys removes duplicates while preserving first-seen order.
        unique_filenames = list(
            dict.fromkeys(existing_filenames + new_filenames)
        )
        for filename in new_filenames:
            existing_status.setdefault(filename, "processing")
        return (
            _render_file_items(unique_filenames, existing_status),
            unique_filenames,
            existing_status,
        )

    if not existing_filenames:
        return [html.Li("No files uploaded yet")], [], {}

    return (
        _render_file_items(existing_filenames, existing_status),
        existing_filenames,
        existing_status,
    )


def _upsert_vendor_metrics(records, vendor_name, metrics):
    """Merge ``metrics`` into the record named ``vendor_name`` in ``records``.

    If no record has that name, append ``{"name": vendor_name, **metrics}``.
    """
    for record in records:
        if record["name"] == vendor_name:
            record.update(metrics)
            return
    records.append({"name": vendor_name, **metrics})


@app.callback(
    Output("action-of-interest-data", "children"),
    Output("action-recognition-data", "children"),
    Output("file-status", "data", allow_duplicate=True),
    [Input("upload-data", "contents")],
    [
        State("upload-data", "filename"),
        State("action-of-interest-data", "children"),
        State("action-recognition-data", "children"),
        State("file-status", "data"),
    ],
    prevent_initial_call=True,
)
def update_output(
    new_contents,
    new_filenames,
    existing_interest_data,
    existing_recognition_data,
    file_status,
):
    """Evaluate uploaded prediction files and merge their metrics per vendor.

    Args:
        new_contents: Upload contents, one base64 data URL per file.
        new_filenames: Filenames matching ``new_contents`` position by position.
        existing_interest_data: JSON list of ``{"name", "detection_rate",
            "false_positive_rate"}`` records, or empty.
        existing_recognition_data: JSON list of ``{"name", "accuracy"}``
            records, or empty.
        file_status: Mapping filename -> status, or None.

    Returns:
        ``(interest_json, recognition_json, file_status)``. Each file is
        marked ``"success"`` or ``"failed"``. The vendor name is the filename
        without its final extension, so a re-upload overwrites that vendor's
        metrics.
    """
    if not new_contents:
        return existing_interest_data, existing_recognition_data, file_status

    interest_records = (
        json.loads(existing_interest_data) if existing_interest_data else []
    )
    recognition_records = (
        json.loads(existing_recognition_data)
        if existing_recognition_data
        else []
    )
    # The store may be empty (None); marking a file must never crash.
    file_status = dict(file_status or {})

    for contents, filename in zip(new_contents, new_filenames):
        temp_file_path = None
        try:
            # Data URL format: "<header>,<base64 payload>".
            _, content_string = contents.split(",", 1)
            decoded = base64.b64decode(content_string)

            # The pipeline takes a file path, so stage the upload on disk.
            with tempfile.NamedTemporaryFile(
                mode="wb", delete=False, suffix=".json", prefix="temp_"
            ) as temp_file:
                temp_file.write(decoded)
                temp_file_path = temp_file.name

            detection_rate, false_positive_rate, accuracy = (
                get_metrics_result_from_vendor_prediction(temp_file_path)
            )
            # splitext keeps "acme.v1" and "acme.v2" as distinct vendors.
            vendor_name = os.path.splitext(filename)[0]

            _upsert_vendor_metrics(
                interest_records,
                vendor_name,
                {
                    "detection_rate": detection_rate,
                    "false_positive_rate": false_positive_rate,
                },
            )
            _upsert_vendor_metrics(
                recognition_records, vendor_name, {"accuracy": accuracy}
            )

            file_status[filename] = "success"

        except Exception as e:
            # Best-effort per file: one bad upload must not block the others.
            print(f"There was an error processing file {filename}: {e}")
            file_status[filename] = "failed"

        finally:
            # delete=False above means we own the cleanup of the temp file.
            if temp_file_path is not None:
                try:
                    os.remove(temp_file_path)
                except OSError as err:
                    print(f"Could not remove temp file {temp_file_path}: {err}")

    return (
        json.dumps(interest_records),
        json.dumps(recognition_records),
        file_status,
    )


# Main chart callback: re-renders whenever the task, the chart type or either
# stored metrics payload changes.
@app.callback(
    Output("chart-container", "children"),
    [
        Input("task-selector", "value"),
        Input("chart-type-dropdown", "value"),
        Input("action-of-interest-data", "children"),
        Input("action-recognition-data", "children"),
    ],
)
def update_charts(
    selected_task,
    selected_chart_type,
    action_of_interest_data,
    action_recognition_data,
):
    """Build the chart components for the selected task.

    Args:
        selected_task: ``"action_of_interest"`` or ``"action_recognition"``.
        selected_chart_type: ``"radar"`` selects radar charts; any other
            value selects bar charts.
        action_of_interest_data: JSON list of vendor records with
            ``detection_rate`` and ``false_positive_rate``, or empty.
        action_recognition_data: JSON list of vendor records with
            ``accuracy``, or empty.

    Returns:
        A list of Dash components for ``chart-container``. For any other task
        value it returns an empty list (presumably None arrives when the task
        dropdown is cleared).
    """
    chart_func = (
        create_radar_chart
        if selected_chart_type == "radar"
        else create_bar_chart
    )

    if selected_task == "action_of_interest":
        data = (
            json.loads(action_of_interest_data)
            if action_of_interest_data
            else []
        )
        categories = CATEGORIES_ACTION_OF_INTEREST
        chart1 = dcc.Graph(
            figure=chart_func(data, "detection_rate", categories),
            style={"height": "1000px", "width": "100%"},
        )
        chart2 = dcc.Graph(
            figure=chart_func(data, "false_positive_rate", categories),
            style={"height": "1000px", "width": "100%"},
        )
        return [
            html.Div(style={"height": "10px"}),
            html.Div(chart1, style={"width": "100%"}),
            html.Div(style={"height": "1300px"}),
            html.Div(chart2, style={"width": "100%"}),
            html.Div(style={"height": "1200px"}),
        ]

    if selected_task == "action_recognition":
        data = (
            json.loads(action_recognition_data)
            if action_recognition_data
            else []
        )
        if data:
            # Plot only the categories that every vendor reports.
            common_categories = set(data[0]["accuracy"])
            for vendor_data in data[1:]:
                common_categories.intersection_update(vendor_data["accuracy"])
            categories = sorted(common_categories)
        else:
            categories = []
        chart = dcc.Graph(
            figure=chart_func(data, "accuracy", categories),
            style={"height": "1000px", "width": "100%"},
        )
        return [
            html.Div(chart, style={"height": "1000px", "width": "100%"}),
            html.Div(style={"height": "1000px", "width": "100%"}),
        ]

    return []


@app.callback(
    Output("main-title", "children"), [Input("task-selector", "value")]
)
def update_title(selected_task):
    """Return the page heading for the selected task.

    Any value other than the two known task ids (including None) falls back
    to the generic "BENCHMARK" heading.
    """
    if selected_task == "action_of_interest":
        return "ACTION OF INTEREST BENCHMARK"
    if selected_task == "action_recognition":
        return "ACTION RECOGNITION BENCHMARK"
    return "BENCHMARK"


# Run the development server when executed as a script. ``app.run`` replaces
# ``app.run_server``, which was deprecated in Dash 2 and removed in Dash 3.
if __name__ == "__main__":
    app.run(debug=True)
Leave a Comment