Untitled

 avatar
unknown
plain_text
2 years ago
3.4 kB
13
Indexable
    def _compute_best_threshold(
        cls,
        metric: AggregatedBootstrappedLCBRecallAtSpecificityDefinition,
        aggr_labels: np.ndarray,
        aggr_scores: np.ndarray,
        weights: np.ndarray,
    ) -> Tuple[float, float, float]:
        """Find the score threshold maximizing the recall LCB subject to a
        specificity-LCB constraint.

        Walks the threshold from a 0.81 starting point in adaptive steps:
        each time the bootstrapped specificity LCB crosses
        ``metric.min_specificity`` between iterations, the step is halved and
        the direction reversed (bisection-style). Confusion counts are
        Poisson-bootstrapped in parallel over 20k-sample chunks.

        Args:
            metric: metric definition providing ``min_specificity``
                (the specificity-LCB constraint).
            aggr_labels: per-sample ground-truth labels.
            aggr_scores: per-sample scores compared against the threshold.
            weights: per-sample weights passed through to the bootstrap worker.

        Returns:
            ``(specificity_lcb, recall_lcb, threshold)`` of the feasible
            candidate with the highest recall LCB, or ``(0, 0, 1)`` when no
            threshold satisfied the constraint.
        """
        # os.cpu_count() may return None when undetermined; fall back to one
        # worker in that case. Leave one core free for the parent process.
        cpu_to_use = max((os.cpu_count() or 2) - 1, 1)
        # "spawn" avoids fork-related issues with threaded/NumPy parents;
        # force=True because the start method may already have been set.
        mp.set_start_method("spawn", force=True)
        total_samples = len(aggr_labels)
        chunk_size = 20000
        epsilon = sys.float_info.epsilon  # guards divisions by zero below
        min_specificity = metric.min_specificity

        n_changes = 0  # how many times the specificity LCB crossed the constraint
        specificity_lcb = 1.0
        step = 0.05
        coef = 1.0  # +1 moves the threshold down, -1 moves it up
        iters = 0
        threshold = 0.81
        # Starting point rationale: if 0.81 is too low, the walk reaches the
        # top of the threshold range in about 8 steps; if too high, it bottoms
        # out and bounces back (with the step halving on each crossing) well
        # within the iteration budget. The iters < 18 cap below bounds the
        # search in either case while leaving room to explore non-monotonic
        # regions of the specificity curve.
        final_results = []

        while (n_changes < 3 or specificity_lcb < min_specificity) and iters < 18:
            # Lazily build one work item per chunk:
            # (threshold, [(label, score, weight, global_index), ...]).
            generator = (
                (
                    threshold,
                    [
                        (
                            aggr_labels[x + y],
                            aggr_scores[x + y],
                            weights[x + y],
                            x + y,
                        )
                        for x in range(chunk_size)
                        if x + y < total_samples
                    ],
                )
                for y in range(0, total_samples, chunk_size)
            )

            with mp.Pool(processes=cpu_to_use) as pool:
                stat_list = list(pool.imap(poisson_bootstrap_tp_fp_fn_tn, generator))

            # Sum bootstrapped confusion counts across chunks, then take the
            # 0.5th percentile over bootstrap replicates as the lower
            # confidence bound (LCB).
            TP, FP, FN, TN = np.sum(stat_list, 0)
            specificity = TN / (TN + FP + epsilon)
            recall = TP / (TP + FN + epsilon)
            recall_lcb = np.percentile(recall, 0.5)
            specificity_old = specificity_lcb
            specificity_lcb = np.percentile(specificity, 0.5)

            # Record every feasible candidate (threshold < 1 so it is usable).
            if specificity_lcb >= min_specificity and threshold < 1:
                final_results.append((specificity_lcb, recall_lcb, threshold))

            if (
                specificity_lcb < min_specificity and specificity_old < min_specificity
            ) or (
                specificity_lcb > min_specificity and specificity_old > min_specificity
            ):
                # Still on the same side of the constraint: keep walking.
                pass
            else:
                # Crossed (or landed exactly on) the constraint: halve the
                # step and reverse direction, bisection style.
                step /= 2.0
                coef *= -1.0
                n_changes += 1
            threshold -= coef * step
            iters += 1

        if len(final_results) == 0:
            # No feasible threshold found: signal with threshold = 1
            # (nothing passes) and zeroed LCBs.
            specificity_lcb = 0
            recall_lcb = 0
            threshold = 1
        else:
            # Among feasible candidates, pick the one with the best recall LCB.
            final_results = sorted(final_results, key=lambda x: x[1], reverse=True)
            specificity_lcb, recall_lcb, threshold = final_results[0]

        return specificity_lcb, recall_lcb, threshold
Editor is loading...