Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
2.2 kB
4
Indexable
Never
def calc_auc(pred_bboxes, gt_bboxes, draw=False):
    """
    Compute the area under the precision-recall curve of a detector.

    Within every image the detections are visited in order of decreasing
    confidence. Each one is compared (via ``calc_iou``) with the ground-truth
    boxes of that image that have not been matched yet; if the best overlap
    exceeds 0.5 the detection is a true positive and that ground-truth box is
    consumed, otherwise the detection is a false positive. Precision and
    recall are then evaluated at confidence thresholds placed below the
    lowest score, halfway between consecutive sorted scores and above the
    highest score, and the resulting curve is integrated with the
    trapezoidal rule.

    :param pred_bboxes: dict of bboxes in format {filename: detections}
        detections is a N x 5 array, where N is number of detections. Each
        detection is described using 5 numbers: [row, col, n_rows, n_cols,
        confidence]. ``None`` or empty entries are skipped.
    :param gt_bboxes: dict of bboxes in format {filenames: bboxes}. bboxes is a
        list of tuples in format (row, col, n_rows, n_cols). A filename that
        is missing here is treated as having no ground-truth boxes.
    :param draw: if True, plot the precision-recall curve with ``plt.plot``.
    :return: auc measure for given detections and gt; 0.0 when there are no
        detections or no ground-truth boxes at all.
    """
    # your code here \/
    iou_threshold = 0.5
    tp_scores = []
    fp_scores = []
    for filename, preds in pred_bboxes.items():
        if preds is None:
            continue
        preds = np.asarray(preds)
        if preds.size == 0:
            continue
        # Work on a private list so matched boxes can be dropped without
        # touching the caller's data.
        unmatched = list(gt_bboxes.get(filename, ()))
        for det in preds[np.argsort(preds[:, -1])[::-1]]:
            score = float(det[-1])
            # Pick the best-overlapping free ground-truth box rather than the
            # first acceptable one, so a detection cannot consume a box that
            # a later detection fits better.
            best_idx, best_iou = -1, iou_threshold
            for idx, gt_bbox in enumerate(unmatched):
                iou = calc_iou(det[:-1], gt_bbox)
                if iou > best_iou:
                    best_idx, best_iou = idx, iou
            if best_idx >= 0:
                # Delete by position: works for tuples and array rows alike.
                del unmatched[best_idx]
                tp_scores.append(score)
            else:
                fp_scores.append(score)

    total_gt = sum(len(bboxes) for bboxes in gt_bboxes.values())
    if total_gt == 0 or not (tp_scores or fp_scores):
        # Recall is undefined without ground truth and there is no curve
        # without detections; report zero area instead of crashing / NaN.
        return 0.0

    tp_sorted = np.sort(np.array(tp_scores, dtype=float))
    all_sorted = np.sort(np.array(tp_scores + fp_scores, dtype=float))

    eps = 1e-4
    thresholds = np.concatenate((
        [all_sorted[0] - eps],
        (all_sorted[1:] + all_sorted[:-1]) / 2,
        [all_sorted[-1] + eps],
    ))

    # In an ascending array, the number of entries >= t is
    # size - (index of the first entry >= t).
    n_pred = all_sorted.size - np.searchsorted(all_sorted, thresholds, side="left")
    n_tp = tp_sorted.size - np.searchsorted(tp_sorted, thresholds, side="left")

    # Precision is defined as 1 where nothing is predicted positive.
    precisions = np.ones(thresholds.size)
    np.divide(n_tp, n_pred, out=precisions, where=n_pred > 0)
    recalls = n_tp / total_gt

    if draw:
        plt.plot(recalls, precisions)

    # Recall does not increase as the threshold rises; flip both arrays so
    # the trapezoidal integral runs over non-decreasing recall.
    rec = recalls[::-1]
    prec = precisions[::-1]
    return float(np.sum(np.diff(rec) * (prec[1:] + prec[:-1]) / 2))
    # your code here /\