Untitled

 avatar
unknown
plain_text
5 months ago
4.2 kB
3
Indexable
import numpy as np
from sklearn.metrics import precision_score, recall_score
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import HoverTool, FuncTickFormatter, Arrow, NormalHead, Label
from bokeh.models import ColumnDataSource

# Assuming you have 'target' (true labels) and 'prob' (predicted probabilities)
# Example:
# target = [1, 0, 1, 0, 1]
# prob = [0.9, 0.1, 0.8, 0.2, 0.7]

# Define thresholds: 10 evenly spaced values in [0.1, 0.5), exactly 0.5, then
# 10 evenly spaced values in (0.5, 0.9]. Each threshold appears once and the
# sequence is ascending, so the curve below is traced in threshold order
# (the previous linspace(0.1, 0.0, 10) ran backwards and 0.5 was repeated).
thresholds_1 = np.linspace(0.1, 0.5, 10, endpoint=False)  # 0.10, 0.14, ..., 0.46
thresholds_2 = np.linspace(0.5, 0.9, 11)[1:]  # 0.54, 0.58, ..., 0.90 (0.5 excluded)
thresholds = np.concatenate([thresholds_1, [0.5], thresholds_2])  # Combine all thresholds

# Compute precision and recall at each threshold.
# Convert the inputs once up front: `prob >= t` raises TypeError when `prob`
# is a plain Python list (the form shown in the commented example at the top
# of this script), whereas a NumPy array compares element-wise.
target_arr = np.asarray(target)
prob_arr = np.asarray(prob)

precision_values = []
recall_values = []

for t in thresholds:
    preds = (prob_arr >= t).astype(int)  # Binarize predictions based on threshold
    # zero_division=0: at high thresholds there may be no predicted positives,
    # which leaves precision undefined. Report 0 explicitly rather than
    # emitting an UndefinedMetricWarning for every such threshold.
    precision_values.append(precision_score(target_arr, preds, zero_division=0))
    recall_values.append(recall_score(target_arr, preds, zero_division=0))

# Convert precision and recall to percentages (multiply by 100)
precision_percent = [p * 100 for p in precision_values]
recall_percent = [r * 100 for r in recall_values]

# Find the index of the threshold closest to 0.5 for highlighting.
# Matching by nearest value instead of exact float equality
# (`thresholds == 0.5`) avoids an IndexError if 0.5 is not bit-exactly
# present. np.argmin returns the first index if 0.5 occurs more than once.
highlight_index = int(np.argmin(np.abs(np.asarray(thresholds) - 0.5)))
highlight_precision = precision_percent[highlight_index]
highlight_recall = recall_percent[highlight_index]
highlight_threshold = thresholds[highlight_index]

# Create a Bokeh plot
output_notebook()  # To display in Jupyter notebook

# One data source is shared by the curve and its markers, so the hover
# tooltip can read a per-point "threshold" column. Glyphs built from bare
# lists only expose @x / @y, so "@threshold" would render as "???".
pr_source = ColumnDataSource(data={
    "recall": recall_percent,
    "precision": precision_percent,
    "threshold": np.asarray(thresholds),
})

p = figure(title="Precision-Recall Curve (Custom Thresholds)",
           x_axis_label="Recall (%)",
           y_axis_label="Precision (%)",
           width=800, height=400)

# Line for the precision-recall curve. legend_label is omitted rather than
# set to None: Bokeh requires legend_label to be a string, and leaving it out
# is what keeps the line out of the legend.
p.line("recall", "precision", source=pr_source,
       line_width=2, color="blue", alpha=0.7)

# Markers for every threshold, with a legend entry. scatter() is used because
# circle(size=...) is deprecated in newer Bokeh releases.
threshold_points = p.scatter("recall", "precision", source=pr_source,
                             marker="circle", size=10, color="red", alpha=0.7,
                             legend_label="Threshold Points")

# Highlight threshold at 0.5 with a different color and a legend entry
p.scatter([highlight_recall], [highlight_precision], marker="circle",
          size=15, color="green", alpha=1, line_color="black",
          legend_label="Threshold 0.5")

# Arrow running from near the text label to the highlighted point, with the
# head at the point. Arrow and Label are annotations, so they are attached
# with add_layout(); add_glyph() rejects non-glyph models.
arrow = Arrow(end=NormalHead(size=10, fill_color="green", line_color="green"),
              x_start=highlight_recall + 5, y_start=highlight_precision + 5,
              x_end=highlight_recall, y_end=highlight_precision,
              line_width=2, line_color="green")
p.add_layout(arrow)

# Text with threshold and precision, placed just beyond the arrow's tail
text = Label(x=highlight_recall + 7, y=highlight_precision + 7,
             text=f"Threshold: {highlight_threshold:.2f}\nPrecision: {highlight_precision:.2f}%",
             text_font_size="12px", text_color="green", background_fill_alpha=0.4, background_fill_color="white")
p.add_layout(text)

# Tooltips on the per-threshold markers only. Their source carries the
# threshold column; values are already percentages, so "%" is appended
# literally.
hover = HoverTool(renderers=[threshold_points], tooltips=[
    ("Threshold", "@threshold{0.00}"),
    ("Recall", "@recall{0.00}%"),
    ("Precision", "@precision{0.00}%"),
])
p.add_tools(hover)

# Format the axis ticks to display percentages with the % symbol.
# NOTE(review): Bokeh 3.x renamed FuncTickFormatter to CustomJSTickFormatter.
# If the installed Bokeh no longer provides the old name, switch both here
# and in the import.
p.xaxis.formatter = FuncTickFormatter(code="""
    return tick + '%';
""")
p.yaxis.formatter = FuncTickFormatter(code="""
    return tick + '%';
""")

# Increase the size of the axis labels, make them bold and italic
p.xaxis.axis_label_text_font_size = "16px"  # Set x-axis label font size
p.yaxis.axis_label_text_font_size = "16px"  # Set y-axis label font size
p.xaxis.axis_label_text_font_style = "bold italic"  # Make x-axis label bold and italic
p.yaxis.axis_label_text_font_style = "bold italic"  # Make y-axis label bold and italic

# Add some additional styling
p.legend.location = "bottom_left"
p.legend.click_policy = "hide"
p.grid.grid_line_alpha = 0.3
p.title.text_font_size = "16px"

# Show the plot
show(p)

# Echo the highlighted operating point as plain text, alongside the plot
print("Threshold 0.5: Precision = {:.2f}%, Recall = {:.2f}%".format(
    highlight_precision, highlight_recall))
Editor is loading...
Leave a Comment