Untitled
unknown
plain_text
19 days ago
20 kB
2
Indexable
Never
"""Dash dashboard comparing vendor predictions on the HTX video benchmark.

Users upload one prediction JSON per vendor. Each file is evaluated with
``HTXEvaluationPipeline`` and the resulting per-action metrics are shown as
radar or bar charts, for either the "action of interest" task (detection
rate / false positive rate) or the "action recognition" task (accuracy).
"""

import base64
import hashlib
import json
import os
import random
import tempfile
from typing import Dict, List, Tuple

import dash
import plotly.graph_objects as go
from dash import dcc, html
from dash.dependencies import Input, Output, State

from flash_core.video_evaluation.htx.evaluation_pipeline import (
    HTXEvaluationPipeline,
)

# Constants
# Categories (and their axis order) plotted for the "action_of_interest" task.
CATEGORIES_ACTION_OF_INTEREST = [
    "falling",
    "fighting",
    "self-harm",
    "none",
    "all_actions",
]
# Per-vendor trace colours; reused cyclically when there are more vendors.
COLORS = [
    "#FF0000",
    "#00FF00",
    "#0000FF",
    "#FFA500",
    "#800080",
    "#00FFFF",
    "#FFD700",
    "#008080",
    "#800000",
    "#008000",
]

# Single pipeline instance shared by every callback, created at import time.
htx_pipeline = HTXEvaluationPipeline()

# Display titles per metric key; any other key falls back to FPR (as before).
_METRIC_TITLES = {
    "accuracy": "Accuracy",
    "detection_rate": "Detection Rate",
}
# Metrics where a larger value is better; every other metric is "lower".
_HIGHER_IS_BETTER = frozenset({"detection_rate", "accuracy"})
# Page headings per task selector value.
_TASK_TITLES = {
    "action_of_interest": "ACTION OF INTEREST BENCHMARK",
    "action_recognition": "ACTION RECOGNITION BENCHMARK",
}


def _metric_title(data_type: str) -> str:
    """Return the chart title for a metric key."""
    return _METRIC_TITLES.get(data_type, "False Positive Rate")


def _metric_annotation(data_type: str) -> str:
    """Return the "which direction is better" hint for a metric key."""
    if data_type in _HIGHER_IS_BETTER:
        return "Higher is better"
    return "Lower is better"


def _scaled_metric(results: Dict, key: str) -> Dict[str, float]:
    """Map each action in *results* to ``results[action][key] * 100``."""
    return {action: results[action][key] * 100 for action in results}


def get_metrics_result_from_vendor_prediction(
    results_prediction_path,
) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]:
    """Evaluate one vendor prediction file and return its per-action metrics.

    The pipeline is run twice on *results_prediction_path*, once with
    ``task="action_of_interest"`` and once with ``task="action_recognition"``,
    and the ``["all_video"]["result"]`` section of each output is read.

    Returns:
        ``(detection_rate, false_positive_rate, accuracy)``: dicts mapping
        action name to the pipeline value multiplied by 100 (presumably a
        fraction converted to percent, matching the 0-100 chart axes).
        The first two are keyed by the action-of-interest results, the
        third by the action-recognition results.
    """
    action_of_interest_results = htx_pipeline.evaluate(
        results_prediction_path, task="action_of_interest"
    )["all_video"]["result"]
    action_recognition_results = htx_pipeline.evaluate(
        results_prediction_path, task="action_recognition"
    )["all_video"]["result"]

    detection_rate = _scaled_metric(action_of_interest_results, "detection_rate")
    false_positive_rate = _scaled_metric(
        action_of_interest_results, "false_positive_rate"
    )
    accuracy = _scaled_metric(action_recognition_results, "accuracy")
    return detection_rate, false_positive_rate, accuracy


# Chart creation functions
def create_radar_chart(data, data_type, categories):
    """Build a radar (polar) chart comparing vendors on one metric.

    Args:
        data: list of vendor dicts, each with a ``"name"`` key and a
            ``data_type`` key mapping category -> value (0-100 axis).
        data_type: metric key to plot, e.g. ``"detection_rate"``.
        categories: ordered category names used as the angular axis.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Dashed spokes from the centre to every category vertex.
    for category in categories:
        fig.add_trace(
            go.Scatterpolar(
                r=[0, 100],
                theta=[category, category],
                mode="lines",
                line=dict(color="gray", width=1, dash="dash"),
                showlegend=False,
                hoverinfo="skip",
            )
        )

    metric_label = data_type.replace("_", " ").title()
    # Repeat the first vertex so the filled polygon closes; slicing (not
    # categories[0]) keeps this safe when there are no categories.
    theta = list(categories) + list(categories[:1])
    for index, vendor_data in enumerate(data):
        # Missing categories plot as 0, consistent with create_bar_chart.
        values = [vendor_data[data_type].get(cat, 0) for cat in categories]
        fig.add_trace(
            go.Scatterpolar(
                r=values + values[:1],
                theta=theta,
                name=vendor_data["name"],
                fill="toself",
                # Cycle colours so vendors beyond len(COLORS) are still drawn.
                line=dict(color=COLORS[index % len(COLORS)], width=3),
                opacity=0.7,
                hovertemplate=(
                    "<b>%{theta}</b><br>"
                    f"{metric_label}: %{{r}}<br>"
                    "<extra></extra>"
                ),
            )
        )

    title = _metric_title(data_type)
    fig.update_polars(angularaxis_type="category")
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                visible=True,
                range=[0, 100],
                tickfont=dict(size=14),
                tickangle=30,
                tickvals=[20, 40, 60, 80, 100],
                ticktext=["20", "40", "60", "80", "100"],
            ),
            angularaxis=dict(
                tickfont=dict(size=22, color="red"),
                rotation=45,
                direction="clockwise",
            ),
            bgcolor="rgba(240, 240, 240, 0.8)",
            gridshape="circular",
        ),
        showlegend=True,
        legend=dict(
            font=dict(size=22),
            orientation="h",
            yanchor="bottom",
            y=-0.1,
            xanchor="center",
            x=0.5,
        ),
        title=dict(
            text=f"{title} - Performance Comparison",
            font=dict(size=30, family="Arial", color="black"),
            x=0.5,
            y=0.96,
        ),
        height=1000,
        margin=dict(t=100, b=80, l=80, r=80),
    )
    fig.add_annotation(
        x=0.7,
        y=-0.03,
        xref="paper",
        yref="paper",
        text=_metric_annotation(data_type),
        showarrow=False,
        font=dict(size=18),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
        borderpad=4,
        opacity=0.8,
    )
    return fig


def create_bar_chart(data, data_type, categories):
    """Build a grouped bar chart comparing vendors on one metric.

    Args:
        data: list of vendor dicts, each with a ``"name"`` key and a
            ``data_type`` key mapping category -> value (0-100 axis).
        data_type: metric key to plot, e.g. ``"accuracy"``.
        categories: category names shown on the x axis, in order.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    for index, vendor_data in enumerate(data):
        # Missing categories plot as 0.
        values = [vendor_data[data_type].get(cat, 0) for cat in categories]
        fig.add_trace(
            go.Bar(
                x=categories,
                y=values,
                name=vendor_data["name"],
                marker_color=COLORS[index % len(COLORS)],
                opacity=0.7,
            )
        )

    title = _metric_title(data_type)
    fig.update_layout(
        title=dict(
            text=f"{title} - Performance Comparison",
            font=dict(size=30, family="Arial", color="black"),
            x=0.5,
            y=0.96,
        ),
        # Axis title fonts go under title.font; the old `titlefont`
        # property is deprecated and removed in Plotly 6.
        xaxis=dict(
            title=dict(text="Categories", font=dict(size=24)),
            tickfont=dict(size=14),
        ),
        yaxis=dict(
            title=dict(text=title, font=dict(size=16)),
            tickfont=dict(size=14),
            range=[0, 100],
            tickvals=[20, 40, 60, 80, 100],
            ticktext=["20", "40", "60", "80", "100"],
        ),
        barmode="group",
        bargap=0.25,
        bargroupgap=0.05,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=-0.2,
            xanchor="center",
            x=0.5,
            font=dict(size=22),
        ),
        annotations=[
            dict(
                x=0.85,
                y=-0.1,
                xref="paper",
                yref="paper",
                text=_metric_annotation(data_type),
                showarrow=False,
                bgcolor="white",
                bordercolor="black",
                borderwidth=1,
                borderpad=4,
                font=dict(size=18),
            )
        ],
        height=1000,
        margin=dict(t=120, b=80, l=0, r=80),
        plot_bgcolor="rgba(240, 240, 240, 0.8)",
    )
    fig.update_xaxes(
        tickfont=dict(color="black"),
        tickvals=categories,
        ticktext=[
            f'<span style="color: red; font-weight: bold;">{cat}</span>'
            for cat in categories
        ],
    )
    fig.update_yaxes(gridcolor="lightgray", gridwidth=1, griddash="dash")
    fig.update_xaxes(gridcolor="lightgray", gridwidth=1, griddash="dash")
    return fig


# App layout components
def create_sidebar():
    """Return the fixed left sidebar: upload box, ingested-file list, task picker."""
    return html.Div(
        [
            dcc.Upload(
                id="upload-data",
                children=html.Div(
                    [
                        html.P(
                            "Upload Files \n or \n Drag and Drop\nhere",
                            style={
                                "whiteSpace": "pre-line",
                                "color": "#007bff",
                                "fontWeight": "bold",
                            },
                        )
                    ]
                ),
                style={
                    "width": "88%",
                    "height": "150px",
                    "lineHeight": "30px",
                    "borderWidth": "3px",
                    "borderStyle": "groove",
                    "borderColor": "#bf1f24",
                    "borderRadius": "20px",
                    "textAlign": "center",
                    "backgroundColor": "#ffffff",
                    "margin": "10px",
                },
                multiple=True,
            ),
            html.H2("Ingested Files"),
            html.Ul(
                id="file-list",
                style={
                    "listStyleType": "none",
                    "fontSize": "18px",
                    "padding": "0",
                    "backgroundColor": "#e9ecef",
                    "borderRadius": "10px",
                    "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.1)",
                    "lineHeight": "30px",
                },
            ),
            html.H2("Select Task"),
            dcc.Dropdown(
                id="task-selector",
                options=[
                    {
                        "label": "Action of Interest",
                        "value": "action_of_interest",
                    },
                    {
                        "label": "Action Recognition",
                        "value": "action_recognition",
                    },
                ],
                value="action_of_interest",
                style={"marginBottom": "20px"},
            ),
        ],
        style={
            "width": "9%",
            # React style keys are camelCase (matches the rest of the file).
            "backgroundColor": "#f2f2f2",
            "padding": "2px",
            "position": "fixed",
            "height": "100%",
            "overflow": "auto",
        },
    )


def create_main_content():
    """Return the main pane.

    The pane holds the title, the chart-type picker and the chart container.
    Evaluated metrics live as JSON strings in two hidden divs.
    """
    return html.Div(
        [
            html.H1(
                id="main-title",
                style={
                    "textAlign": "center",
                    "fontSize": "40px",
                    "marginBottom": "20px",
                    "marginLeft": "10%",
                },
            ),
            html.Div(
                dcc.Dropdown(
                    id="chart-type-dropdown",
                    options=[
                        {"label": "Radar Chart", "value": "radar"},
                        {"label": "Bar Chart", "value": "bar"},
                    ],
                    value="radar",
                    style={"width": "50%", "marginLeft": "28%"},
                )
            ),
            html.Div(id="action-of-interest-data", style={"display": "none"}),
            html.Div(id="action-recognition-data", style={"display": "none"}),
            html.Div(
                id="chart-container",
                style={
                    "marginLeft": "10%",
                    "display": "flex",
                    "flexWrap": "wrap",
                },
            ),
        ]
    )


# Initialize the Dash app
app = dash.Dash(__name__)

# The stores hold the ordered list of uploaded filenames and a
# filename -> "success"/"failed" processing-status map.
app.layout = html.Div(
    [
        dcc.Store(id="uploaded-filenames", storage_type="memory"),
        dcc.Store(id="file-status", storage_type="memory"),
        create_sidebar(),
        create_main_content(),
    ]
)


def _status_color(status):
    """Return the list colour for a file status.

    Statuses map as follows: no entry yet or "processing" -> orange,
    "success" -> green, anything else -> red.
    """
    if status is None or status == "processing":
        return "orange"
    return "green" if status == "success" else "red"


@app.callback(
    Output("file-list", "children"),
    Output("uploaded-filenames", "data"),
    [Input("upload-data", "filename"), Input("file-status", "data")],
    [State("uploaded-filenames", "data")],
)
def update_file_list(new_filenames, file_status, existing_filenames):
    """Render the ingested-file list, coloured by processing status.

    The list re-renders both on new uploads and whenever ``update_output``
    writes a new status map. A file with no status entry is still being
    evaluated and is shown as processing. This callback deliberately does
    not write ``file-status``: ``update_output`` is its only writer, which
    avoids the two callbacks racing and overwriting each other.
    """
    file_status = file_status or {}
    filenames = list(existing_filenames or [])
    if new_filenames:
        # Append new names, dropping duplicates while preserving order.
        filenames = list(dict.fromkeys(filenames + list(new_filenames)))

    if not filenames:
        return [html.Li("No files uploaded yet")], []

    items = [
        html.Li(name, style={"color": _status_color(file_status.get(name))})
        for name in filenames
    ]
    return items, filenames


def _evaluate_upload(contents):
    """Decode one ``dcc.Upload`` payload and return its evaluated metrics.

    The pipeline takes a file path, so the decoded bytes are written to a
    temporary ``.json`` file. That file is always removed afterwards.

    Returns:
        The ``(detection_rate, false_positive_rate, accuracy)`` tuple from
        ``get_metrics_result_from_vendor_prediction``.
    """
    # Upload contents are "<data-url header>,<base64 payload>".
    _, content_string = contents.split(",", 1)
    decoded = base64.b64decode(content_string)
    # delete=False so the file can be closed before the pipeline reopens it.
    with tempfile.NamedTemporaryFile(
        mode="wb", delete=False, suffix=".json", prefix="temp_"
    ) as temp_file:
        temp_file.write(decoded)
        temp_file_path = temp_file.name
    try:
        return get_metrics_result_from_vendor_prediction(temp_file_path)
    finally:
        os.remove(temp_file_path)


def _upsert_vendor(rows, vendor_name, metrics):
    """Merge *metrics* into the row named *vendor_name*, or append a new row."""
    for row in rows:
        if row["name"] == vendor_name:
            row.update(metrics)
            return
    rows.append({"name": vendor_name, **metrics})


@app.callback(
    Output("action-of-interest-data", "children"),
    Output("action-recognition-data", "children"),
    Output("file-status", "data"),
    [Input("upload-data", "contents")],
    [
        State("upload-data", "filename"),
        State("action-of-interest-data", "children"),
        State("action-recognition-data", "children"),
        State("file-status", "data"),
    ],
    prevent_initial_call=True,
)
def update_output(
    new_contents,
    new_filenames,
    existing_interest_data,
    existing_recognition_data,
    file_status,
):
    """Evaluate newly uploaded files and merge their metrics into the stores.

    Each file becomes one vendor, named by the filename up to its first dot.
    Re-uploading a vendor overwrites its earlier metrics. Every file is
    marked "success" or "failed" in the returned status map, and one bad
    file does not abort the rest of the batch.

    Returns:
        JSON strings for the action-of-interest and action-recognition rows,
        plus the updated filename -> status dict.
    """
    if not new_contents:
        return existing_interest_data, existing_recognition_data, file_status

    interest_rows = (
        json.loads(existing_interest_data) if existing_interest_data else []
    )
    recognition_rows = (
        json.loads(existing_recognition_data) if existing_recognition_data else []
    )
    # The status store is empty (None) until the first upload completes.
    file_status = dict(file_status or {})

    for contents, filename in zip(new_contents, new_filenames):
        try:
            detection_rate, false_positive_rate, accuracy = _evaluate_upload(
                contents
            )
        except Exception as e:  # best-effort per file: report and continue
            print(f"There was an error processing file {filename}: {e}")
            file_status[filename] = "failed"
            continue

        vendor_name = filename.split(".")[0]
        _upsert_vendor(
            interest_rows,
            vendor_name,
            {
                "detection_rate": detection_rate,
                "false_positive_rate": false_positive_rate,
            },
        )
        _upsert_vendor(recognition_rows, vendor_name, {"accuracy": accuracy})
        file_status[filename] = "success"

    return (
        json.dumps(interest_rows),
        json.dumps(recognition_rows),
        file_status,
    )


def _shared_accuracy_categories(data):
    """Return the sorted accuracy categories present for every vendor in *data*."""
    if not data:
        return []
    shared = set(data[0]["accuracy"].keys())
    for vendor_data in data[1:]:
        shared.intersection_update(vendor_data["accuracy"].keys())
    return sorted(shared)


@app.callback(
    Output("chart-container", "children"),
    [
        Input("task-selector", "value"),
        Input("chart-type-dropdown", "value"),
        Input("action-of-interest-data", "children"),
        Input("action-recognition-data", "children"),
    ],
)
def update_charts(
    selected_task,
    selected_chart_type,
    action_of_interest_data,
    action_recognition_data,
):
    """Render the charts for the selected task and chart type.

    "action_of_interest" renders two charts: detection rate and false
    positive rate. "action_recognition" renders one accuracy chart over the
    categories shared by all vendors. Any other task renders nothing.
    """
    chart_func = (
        create_radar_chart if selected_chart_type == "radar" else create_bar_chart
    )

    if selected_task == "action_of_interest":
        data = json.loads(action_of_interest_data) if action_of_interest_data else []
        categories = CATEGORIES_ACTION_OF_INTEREST
        chart1 = dcc.Graph(
            figure=chart_func(data, "detection_rate", categories),
            style={"height": "1000px", "width": "100%"},
        )
        chart2 = dcc.Graph(
            figure=chart_func(data, "false_positive_rate", categories),
            style={"height": "1000px", "width": "100%"},
        )
        # Spacer divs keep the tall figures from overlapping in the flex row.
        return [
            html.Div(style={"height": "10px"}),
            html.Div(chart1, style={"width": "100%"}),
            html.Div(style={"height": "1300px"}),
            html.Div(chart2, style={"width": "100%"}),
            html.Div(style={"height": "1200px"}),
        ]

    if selected_task == "action_recognition":
        data = (
            json.loads(action_recognition_data) if action_recognition_data else []
        )
        categories = _shared_accuracy_categories(data)
        chart = dcc.Graph(
            figure=chart_func(data, "accuracy", categories),
            style={"height": "1000px", "width": "100%"},
        )
        return [
            html.Div(chart, style={"height": "1000px", "width": "100%"}),
            html.Div(style={"height": "1000px", "width": "100%"}),
        ]

    return []


@app.callback(
    Output("main-title", "children"), [Input("task-selector", "value")]
)
def update_title(selected_task):
    """Return the page heading for the selected task."""
    return _TASK_TITLES.get(selected_task, "BENCHMARK")


# Run the app
if __name__ == "__main__":
    # `run_server` is deprecated and removed in Dash 3; `run` replaces it.
    app.run(debug=True)
Leave a Comment