Untitled
unknown
plain_text
a year ago
20 kB
13
Indexable
import base64
import hashlib
import json
import logging
import os
import random
import tempfile
from typing import Dict, List, Tuple

import dash
import plotly.graph_objects as go
from dash import dcc, html
from dash.dependencies import Input, Output, State

from flash_core.video_evaluation.htx.evaluation_pipeline import (
    HTXEvaluationPipeline,
)
# Constants

# Action-of-interest categories, in the order they are laid out on the charts.
CATEGORIES_ACTION_OF_INTEREST = ["falling", "fighting", "self-harm", "none", "all_actions"]

# Trace colours, assigned to vendors by their position in the plotted data list.
COLORS = [
    "#FF0000",  # red
    "#00FF00",  # lime
    "#0000FF",  # blue
    "#FFA500",  # orange
    "#800080",  # purple
    "#00FFFF",  # cyan
    "#FFD700",  # gold
    "#008080",  # teal
    "#800000",  # maroon
    "#008000",  # green
]
# Module-level evaluation pipeline, constructed once at import time and shared
# by the metric-extraction code in this file.
htx_pipeline = HTXEvaluationPipeline()
def get_metrics_result_from_vendor_prediction(
    results_prediction_path, pipeline=None
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
    """Evaluate one vendor prediction file and return its metrics as percentages.

    Args:
        results_prediction_path: Path to the vendor prediction file; passed
            unchanged to ``pipeline.evaluate``.
        pipeline: Object exposing ``evaluate(path, task=...)``. Defaults to the
            module-level ``htx_pipeline`` (looked up only when omitted).

    Returns:
        ``(detection_rate, false_positive_rate, accuracy)``: three dicts mapping
        each category name to the metric multiplied by 100. The first two come
        from the ``"action_of_interest"`` task, ``accuracy`` from the
        ``"action_recognition"`` task.
    """
    if pipeline is None:
        pipeline = htx_pipeline

    def _aggregate_result(task):
        # Only the aggregate ("all_video") section of the evaluation is used.
        return pipeline.evaluate(results_prediction_path, task=task)["all_video"][
            "result"
        ]

    def _as_percent(results, metric):
        return {category: values[metric] * 100 for category, values in results.items()}

    interest_results = _aggregate_result("action_of_interest")
    recognition_results = _aggregate_result("action_recognition")
    return (
        _as_percent(interest_results, "detection_rate"),
        _as_percent(interest_results, "false_positive_rate"),
        _as_percent(recognition_results, "accuracy"),
    )
# Chart creation functions
def create_radar_chart(data, data_type, categories):
    """Build a radar (polar) chart comparing every vendor on one metric.

    Args:
        data: List of vendor records; each is a dict with a ``"name"`` key and a
            ``data_type`` key mapping category name -> value. Values are drawn
            against a fixed 0-100 radial axis.
        data_type: ``"accuracy"``, ``"detection_rate"`` or
            ``"false_positive_rate"``; selects the metric and the chart title.
        categories: Ordered category names used as the angular axis.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    categories = list(categories)
    fig = go.Figure()

    # Dashed spokes from the centre to every vertex.
    for category in categories:
        fig.add_trace(
            go.Scatterpolar(
                r=[0, 100],
                theta=[category, category],
                mode="lines",
                line=dict(color="gray", width=1, dash="dash"),
                showlegend=False,
                hoverinfo="skip",
            )
        )

    metric_label = data_type.replace("_", " ").title()
    for i, vendor_data in enumerate(data):
        # Missing categories plot as 0 (same as the bar chart) instead of raising.
        values = [vendor_data[data_type].get(cat, 0) for cat in categories]
        fig.add_trace(
            go.Scatterpolar(
                # Repeat the first point to close the polygon; the slices stay
                # valid (empty) when there are no categories.
                r=values + values[:1],
                theta=categories + categories[:1],
                name=vendor_data["name"],
                fill="toself",
                # Cycle colours so vendors beyond len(COLORS) are still drawn.
                line=dict(color=COLORS[i % len(COLORS)], width=3),
                opacity=0.7,
                hovertemplate=(
                    "<b>%{theta}</b><br>"
                    f"{metric_label}: %{{r}}<br>"
                    "<extra></extra>"
                ),
            )
        )

    title = {"accuracy": "Accuracy", "detection_rate": "Detection Rate"}.get(
        data_type, "False Positive Rate"
    )
    annotation = (
        "Higher is better"
        if data_type in ("detection_rate", "accuracy")
        else "Lower is better"
    )

    fig.update_polars(angularaxis_type="category")
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=14),
                tickangle=30,
                tickvals=[20, 40, 60, 80, 100],
                ticktext=["20", "40", "60", "80", "100"],
            ),
            angularaxis=dict(
                tickfont=dict(size=22, color="red"),
                rotation=45,
                direction="clockwise",
            ),
            bgcolor="rgba(240, 240, 240, 0.8)",
            gridshape="circular",
        ),
        showlegend=True,
        legend=dict(
            font=dict(size=22),
            orientation="h",
            yanchor="bottom",
            y=-0.1,
            xanchor="center",
            x=0.5,
        ),
        title=dict(
            text=f"{title} - Performance Comparison",
            font=dict(size=30, family="Arial", color="black"),
            x=0.5,
            y=0.96,
        ),
        height=1000,
        margin=dict(t=100, b=80, l=80, r=80),
    )
    fig.add_annotation(
        x=0.7,
        y=-0.03,
        xref="paper",
        yref="paper",
        text=annotation,
        showarrow=False,
        font=dict(size=18),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
        borderpad=4,
        opacity=0.8,
    )
    return fig
def create_bar_chart(data, data_type, categories):
    """Build a grouped bar chart comparing every vendor on one metric.

    Args:
        data: List of vendor records; each is a dict with a ``"name"`` key and a
            ``data_type`` key mapping category name -> value. The y axis is
            fixed to 0-100.
        data_type: ``"accuracy"``, ``"detection_rate"`` or
            ``"false_positive_rate"``; selects the metric and the titles.
        categories: Ordered category names used as the x axis.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    for i, vendor_data in enumerate(data):
        # Missing categories plot as 0 rather than raising.
        values = [vendor_data[data_type].get(cat, 0) for cat in categories]
        fig.add_trace(
            go.Bar(
                x=categories,
                y=values,
                name=vendor_data["name"],
                marker_color=COLORS[i % len(COLORS)],
                opacity=0.7,
            )
        )

    title = {"accuracy": "Accuracy", "detection_rate": "Detection Rate"}.get(
        data_type, "False Positive Rate"
    )
    annotation = (
        "Higher is better"
        if data_type in ("detection_rate", "accuracy")
        else "Lower is better"
    )
    grid = dict(gridcolor="lightgray", gridwidth=1, griddash="dash")

    fig.update_layout(
        title=dict(
            text=f"{title} - Performance Comparison",
            font=dict(size=30, family="Arial", color="black"),
            x=0.5,
            y=0.96,
        ),
        # Axis title fonts use ``title.font``; the flat ``titlefont`` property
        # is deprecated and was removed in plotly.py 6.
        xaxis=dict(
            title=dict(text="Categories", font=dict(size=24)),
            tickfont=dict(size=14, color="black"),
            tickvals=categories,
            ticktext=[
                f'<span style="color: red; font-weight: bold;">{cat}</span>'
                for cat in categories
            ],
            **grid,
        ),
        yaxis=dict(
            title=dict(text=title, font=dict(size=16)),
            tickfont=dict(size=14),
            range=[0, 100],
            tickvals=[20, 40, 60, 80, 100],
            ticktext=["20", "40", "60", "80", "100"],
            **grid,
        ),
        barmode="group",
        bargap=0.25,
        bargroupgap=0.05,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.2,
            xanchor="center",
            x=0.5,
            font=dict(size=22),
        ),
        annotations=[
            dict(
                x=0.85,
                y=-0.1,
                xref="paper",
                yref="paper",
                text=annotation,
                showarrow=False,
                bgcolor="white",
                bordercolor="black",
                borderwidth=1,
                borderpad=4,
                font=dict(size=18),
            )
        ],
        height=1000,
        margin=dict(t=120, b=80, l=0, r=80),
        plot_bgcolor="rgba(240, 240, 240, 0.8)",
    )
    return fig
# App layout components
def create_sidebar():
    """Build the fixed left sidebar: file upload area, ingested-file list and task selector."""
    return html.Div(
        [
            dcc.Upload(
                id="upload-data",
                children=html.Div(
                    [
                        html.P(
                            "Upload Files \n or \n Drag and Drop\nhere",
                            style={
                                # pre-line turns the "\n"s into visible line breaks.
                                "whiteSpace": "pre-line",
                                "color": "#007bff",
                                "fontWeight": "bold",
                            },
                        )
                    ]
                ),
                style={
                    "width": "88%",
                    "height": "150px",
                    "lineHeight": "30px",
                    "borderWidth": "3px",
                    "borderStyle": "groove",
                    "borderColor": "#bf1f24",
                    "borderRadius": "20px",
                    "textAlign": "center",
                    "backgroundColor": "#ffffff",
                    "margin": "10px",
                },
                multiple=True,
            ),
            html.H2("Ingested Files"),
            html.Ul(
                id="file-list",
                style={
                    "listStyleType": "none",
                    "fontSize": "18px",
                    "padding": "0",
                    "backgroundColor": "#e9ecef",
                    "borderRadius": "10px",
                    "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.1)",
                    "lineHeight": "30px",
                },
            ),
            html.H2("Select Task"),
            dcc.Dropdown(
                id="task-selector",
                options=[
                    {
                        "label": "Action of Interest",
                        "value": "action_of_interest",
                    },
                    {
                        "label": "Action Recognition",
                        "value": "action_recognition",
                    },
                ],
                value="action_of_interest",
                style={"marginBottom": "20px"},
            ),
        ],
        style={
            "width": "9%",
            # Dash/React style dicts take camelCase keys; the hyphenated
            # "background-color" is not a valid React style property.
            "backgroundColor": "#f2f2f2",
            "padding": "2px",
            "position": "fixed",
            "height": "100%",
            "overflow": "auto",
        },
    )
def create_main_content():
    """Build the main panel.

    The panel holds the page title, the chart-type selector, two hidden divs
    used as data holders, and the container the charts are rendered into.
    """
    page_title = html.H1(
        id="main-title",
        style={
            "textAlign": "center",
            "fontSize": "40px",
            "marginBottom": "20px",
            "marginLeft": "10%",
        },
    )

    chart_type_options = [
        {"label": label, "value": value}
        for label, value in (("Radar Chart", "radar"), ("Bar Chart", "bar"))
    ]
    chart_type_selector = html.Div(
        dcc.Dropdown(
            id="chart-type-dropdown",
            options=chart_type_options,
            value="radar",
            style={"width": "50%", "marginLeft": "28%"},
        )
    )

    # Invisible divs whose ``children`` carry JSON-serialised results between callbacks.
    hidden_data_holders = [
        html.Div(id=holder_id, style={"display": "none"})
        for holder_id in ("action-of-interest-data", "action-recognition-data")
    ]

    chart_container = html.Div(
        id="chart-container",
        style={
            "marginLeft": "10%",
            "display": "flex",
            "flexWrap": "wrap",
        },
    )

    return html.Div(
        [page_title, chart_type_selector, *hidden_data_holders, chart_container]
    )
# Initialize the Dash app
app = dash.Dash(__name__)

# Page layout. Two in-memory stores hold the uploaded file names and a
# filename -> processing-status mapping, followed by the sidebar and main panel.
app.layout = html.Div(
    children=[
        dcc.Store(id="uploaded-filenames", storage_type="memory"),
        dcc.Store(id="file-status", storage_type="memory"),
        create_sidebar(),
        create_main_content(),
    ]
)
@app.callback(
    Output("file-list", "children"),
    Output("uploaded-filenames", "data"),
    Output("file-status", "data"),
    [Input("upload-data", "filename")],
    [State("uploaded-filenames", "data"), State("file-status", "data")],
)
def update_file_list(new_filenames, existing_filenames, existing_status):
    """Render the "Ingested Files" list and register newly uploaded file names.

    Args:
        new_filenames: File names from the latest upload, or ``None`` before
            any upload. This is a list because the Upload component uses
            ``multiple=True``; a single string is also accepted.
        existing_filenames: Previously stored file names, or ``None``.
        existing_status: Mapping of filename -> ``"processing"`` /
            ``"success"`` / ``"failed"``, or ``None``.

    Returns:
        ``(list_items, filenames, status)`` for the three outputs.

    NOTE(review): this callback fires only when ``upload-data.filename``
    changes; ``file-status`` is read as State, so status changes written by the
    upload-processing callback are not re-rendered until the next upload. Both
    callbacks also write ``file-status`` for the same upload event, so whichever
    finishes last wins. Confirm whether this is intended.
    """
    status_colors = {"processing": "orange", "success": "green"}

    def _render(names, status):
        # A name without a status entry is treated as still processing instead
        # of raising KeyError; any other status (e.g. "failed") renders red.
        return [
            html.Li(
                name,
                style={"color": status_colors.get(status.get(name, "processing"), "red")},
            )
            for name in names
        ]

    filenames = list(existing_filenames or [])
    status = dict(existing_status or {})

    if new_filenames:
        if isinstance(new_filenames, str):
            new_filenames = [new_filenames]
        # Remove duplicates while preserving first-seen order.
        filenames = list(dict.fromkeys(filenames + list(new_filenames)))
        for filename in new_filenames:
            status.setdefault(filename, "processing")
        return _render(filenames, status), filenames, status

    if not filenames:
        return [html.Li("No files uploaded yet")], [], {}

    return _render(filenames, status), filenames, status
_logger = logging.getLogger(__name__)


def _upsert_vendor(records, vendor_name, metrics):
    """Merge ``metrics`` into the record named ``vendor_name``, appending a new record if absent."""
    for record in records:
        if record["name"] == vendor_name:
            record.update(metrics)
            return
    records.append({"name": vendor_name, **metrics})


@app.callback(
    Output("action-of-interest-data", "children"),
    Output("action-recognition-data", "children"),
    Output("file-status", "data", allow_duplicate=True),
    [Input("upload-data", "contents")],
    [
        State("upload-data", "filename"),
        State("action-of-interest-data", "children"),
        State("action-recognition-data", "children"),
        State("file-status", "data"),
    ],
    prevent_initial_call=True,
)
def update_output(
    new_contents,
    new_filenames,
    existing_interest_data,
    existing_recognition_data,
    file_status,
):
    """Evaluate each uploaded prediction file and merge its metrics per vendor.

    Args:
        new_contents: Upload payloads, one per file, each ``"<header>,<base64 data>"``.
        new_filenames: File names matching ``new_contents`` position by position.
        existing_interest_data: JSON list of ``{"name", "detection_rate",
            "false_positive_rate"}`` records, or empty.
        existing_recognition_data: JSON list of ``{"name", "accuracy"}`` records, or empty.
        file_status: Mapping filename -> status, or ``None``.

    Returns:
        ``(interest_json, recognition_json, file_status)``. Each processed file is
        marked ``"success"`` or ``"failed"``. The vendor name is the file name
        without its final extension; a re-uploaded vendor overwrites its metrics.
    """
    if not new_contents:
        return existing_interest_data, existing_recognition_data, file_status

    interest_records = (
        json.loads(existing_interest_data) if existing_interest_data else []
    )
    recognition_records = (
        json.loads(existing_recognition_data) if existing_recognition_data else []
    )
    file_status = dict(file_status or {})

    for contents, filename in zip(new_contents, new_filenames):
        temp_file_path = None
        try:
            # Split only on the first comma: the header never contains the payload.
            _, content_string = contents.split(",", 1)
            decoded = base64.b64decode(content_string)
            # The evaluation pipeline reads from a path, so stage the bytes on disk.
            with tempfile.NamedTemporaryFile(
                mode="wb", delete=False, suffix=".json", prefix="temp_"
            ) as temp_file:
                temp_file.write(decoded)
                temp_file_path = temp_file.name

            detection_rate, false_positive_rate, accuracy = (
                get_metrics_result_from_vendor_prediction(temp_file_path)
            )

            # Strip only the final extension so "model.v1.json" and
            # "model.v2.json" stay distinct vendors.
            vendor_name = os.path.splitext(filename)[0]
            _upsert_vendor(
                interest_records,
                vendor_name,
                {
                    "detection_rate": detection_rate,
                    "false_positive_rate": false_positive_rate,
                },
            )
            _upsert_vendor(recognition_records, vendor_name, {"accuracy": accuracy})
            file_status[filename] = "success"
        except Exception:
            # Callback boundary: one bad file must not abort the whole upload batch.
            _logger.exception("There was an error processing file %s", filename)
            file_status[filename] = "failed"
        finally:
            # delete=False leaves the staged file behind; remove it explicitly.
            if temp_file_path is not None:
                try:
                    os.remove(temp_file_path)
                except OSError:
                    _logger.warning("Could not remove temp file %s", temp_file_path)

    return (
        json.dumps(interest_records),
        json.dumps(recognition_records),
        file_status,
    )
# Main chart callback: re-render whenever the task, chart type or data changes.
@app.callback(
    Output("chart-container", "children"),
    [
        Input("task-selector", "value"),
        Input("chart-type-dropdown", "value"),
        Input("action-of-interest-data", "children"),
        Input("action-recognition-data", "children"),
    ],
)
def update_charts(
    selected_task,
    selected_chart_type,
    action_of_interest_data,
    action_recognition_data,
):
    """Build the chart components for the selected task and chart type.

    Args:
        selected_task: ``"action_of_interest"`` or ``"action_recognition"``;
            any other value (e.g. a cleared dropdown) renders nothing.
        selected_chart_type: ``"radar"``; any other value falls back to bars.
        action_of_interest_data: JSON list of vendor records, or empty.
        action_recognition_data: JSON list of vendor records, or empty.

    Returns:
        A list of Dash components for ``chart-container``.
    """
    chart_func = (
        create_radar_chart if selected_chart_type == "radar" else create_bar_chart
    )

    def _graph(data, data_type, categories):
        return dcc.Graph(
            figure=chart_func(data, data_type, categories),
            style={"height": "1000px", "width": "100%"},
        )

    if selected_task == "action_of_interest":
        data = json.loads(action_of_interest_data) if action_of_interest_data else []
        categories = CATEGORIES_ACTION_OF_INTEREST
        chart1 = _graph(data, "detection_rate", categories)
        chart2 = _graph(data, "false_positive_rate", categories)
        # Spacer divs keep the two tall figures from overlapping in the flex container.
        return [
            html.Div(style={"height": "10px"}),
            html.Div(chart1, style={"width": "100%"}),
            html.Div(style={"height": "1300px"}),
            html.Div(chart2, style={"width": "100%"}),
            html.Div(style={"height": "1200px"}),
        ]

    if selected_task == "action_recognition":
        data = json.loads(action_recognition_data) if action_recognition_data else []
        # Only categories reported by every vendor can be compared side by side.
        categories = (
            sorted(set.intersection(*(set(v["accuracy"]) for v in data)))
            if data
            else []
        )
        chart = _graph(data, "accuracy", categories)
        return [
            html.Div(chart, style={"height": "1000px", "width": "100%"}),
            html.Div(style={"height": "1000px", "width": "100%"}),
        ]

    # Task dropdown cleared or unknown task: explicit empty render instead of an implicit None.
    return []
@app.callback(
    Output("main-title", "children"),
    [Input("task-selector", "value")],
)
def update_title(selected_task):
    """Return the page heading for the selected benchmark task."""
    if selected_task == "action_of_interest":
        return "ACTION OF INTEREST BENCHMARK"
    if selected_task == "action_recognition":
        return "ACTION RECOGNITION BENCHMARK"
    # No task selected, or an unrecognised one: generic heading.
    return "BENCHMARK"
# Run the app
if __name__ == "__main__":
    # ``Dash.run_server`` is deprecated and was removed in Dash 3; ``run`` is
    # the supported entry point (available in every Dash that supports
    # ``allow_duplicate``, which this file already relies on).
    app.run(debug=True)
Editor is loading...
Leave a Comment