import numpy as np
import cv2 as cv
import glob
import os
import random
from skimage.feature import hog
from Parameters import Parameters
from sklearn.svm import LinearSVC
import pickle

# Initialize parameters (used to construct the FacialDetector at the end of this file).
params = Parameters()

# Map: annotation file -> sub-folder of ./antrenare holding the images it describes.
ANNOTATIONS_TO_FOLDER = {
    './antrenare/dad_annotations.txt': 'dad',
    './antrenare/deedee_annotations.txt': 'deedee',
    './antrenare/dexter_annotations.txt': 'dexter',
    './antrenare/mom_annotations.txt': 'mom',
}

class FacialDetector:
    """Sliding-window face detector built on HOG descriptors and a linear SVM.

    Configuration comes from ``params``; the code reads ``dim_window``,
    ``dim_hog_cell``, ``threshold``, ``dir_save_files``, ``dir_test_examples``
    and ``path_annotations`` from it.
    """

    def __init__(self, params):
        self.params = params
        # Set by train_classifier(); stays None until a model has been trained.
        self.best_model = None

    def load_annotations(self, file_path):
        """Load bounding-box annotations from a text file.

        Each non-blank line has the form
        ``image_name x_min y_min x_max y_max [extra fields ...]``; fields after
        the four coordinates are ignored and blank lines are skipped.

        Returns:
            dict mapping image_name -> list of (x_min, y_min, x_max, y_max)
            tuples, in file order.
        """
        annotations = {}
        with open(file_path, 'r') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank line (e.g. trailing newline): nothing to record.
                    continue
                image_name = parts[0]
                x_min, y_min, x_max, y_max = map(int, parts[1:5])
                annotations.setdefault(image_name, []).append((x_min, y_min, x_max, y_max))
        return annotations

    def analyze_svm_training(self, positive_features, negative_features):
        """Print SVM scores, coefficients and bias, then plot score histograms.

        Prints the first 10 decision scores of the positive and negative
        features, the first 10 SVM weights and the bias, then shows overlaid
        histograms of both score sets with the decision boundary at 0.
        Only prints a message if no model has been trained yet.
        """
        if self.best_model is None:
            print("SVM-ul nu a fost antrenat încă!")
            return

        positive_scores = self.best_model.decision_function(positive_features)
        print(f"Scoruri pozitive (primele 10): {positive_scores[:10]}")

        negative_scores = self.best_model.decision_function(negative_features)
        print(f"Scoruri negative (primele 10): {negative_scores[:10]}")

        print("Coeficienți SVM (primele 10):")
        print(self.best_model.coef_[0][:10])

        print(f"Bias-ul SVM: {self.best_model.intercept_[0]}")

        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt
        plt.hist(positive_scores, bins=30, alpha=0.5, label="Positive Scores")
        plt.hist(negative_scores, bins=30, alpha=0.5, label="Negative Scores")
        plt.axvline(x=0, color="red", linestyle="--", label="Decision Boundary")
        plt.title("Distribuția scorurilor SVM")
        plt.legend()
        plt.show()

    def train_classifier(self, positive_features, negative_features):
        """Train one LinearSVC per C in a fixed grid and keep the best one.

        The model with the highest accuracy is stored in ``self.best_model``
        and pickled to ``<dir_save_files>/svm_model.pkl``.

        NOTE(review): accuracy is measured on the training data itself, which
        favours large C values; a held-out split would be a less biased choice.
        """
        Cs = [10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1, 10 ** 0, 10 ** 1]
        best_c = None
        # Start below any achievable accuracy so the first model is always kept.
        best_accuracy = -1.0
        best_model = None

        training_examples = np.vstack((positive_features, negative_features))
        # Label 1 for positive examples, 0 for negative examples.
        train_labels = np.hstack((np.ones(positive_features.shape[0]),
                                  np.zeros(negative_features.shape[0])))

        print("Începem căutarea celui mai bun C...")

        for c in Cs:
            print(f"Antrenăm un clasificator pentru C = {c}")
            model = LinearSVC(C=c, max_iter=10000)
            model.fit(training_examples, train_labels)
            acc = model.score(training_examples, train_labels)
            print(f"Precizia pentru C = {c}: {acc}")

            if acc > best_accuracy:
                best_accuracy = acc
                best_c = c
                best_model = model

        self.best_model = best_model
        print(f"Cel mai bun C: {best_c} cu precizia: {best_accuracy}")

        os.makedirs(self.params.dir_save_files, exist_ok=True)
        model_path = os.path.join(self.params.dir_save_files, 'svm_model.pkl')
        with open(model_path, 'wb') as f:
            pickle.dump(best_model, f)
        print(f"Modelul optim SVM a fost salvat în {model_path}.")

    def _compute_hog(self, patch):
        """Return the flattened HOG descriptor of a grayscale window patch."""
        return hog(
            patch,
            pixels_per_cell=(self.params.dim_hog_cell, self.params.dim_hog_cell),
            cells_per_block=(2, 2),
            feature_vector=True,
        )

    def generate_positive_patches(self, image_path, bounding_boxes):
        """Crop every annotated box and resize it to dim_window x dim_window.

        Box coordinates are clipped to the image and empty crops are skipped.
        Returns an empty list if the image cannot be read.
        """
        img = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Eroare la încărcarea imaginii {image_path}.")
            return []

        img_h, img_w = img.shape
        dim = self.params.dim_window
        positive_patches = []
        for (x_min, y_min, x_max, y_max) in bounding_boxes:
            # Clip so negative coordinates don't wrap around via Python slicing.
            x_min, y_min = max(0, x_min), max(0, y_min)
            x_max, y_max = min(img_w, x_max), min(img_h, y_max)
            patch = img[y_min:y_max, x_min:x_max]
            if patch.shape[0] > 0 and patch.shape[1] > 0:
                positive_patches.append(cv.resize(patch, (dim, dim)))
        return positive_patches

    def generate_negative_patches(self, image_path, bounding_boxes, num_negatives=10):
        """Sample random dim_window-sized patches that overlap no annotated box.

        Each of the ``num_negatives`` requested patches gets at most 50 random
        tries, so fewer patches may be returned for crowded images. Returns an
        empty list if the image cannot be read or is smaller than the window.
        """
        img = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Eroare la încărcarea imaginii {image_path}.")
            return []

        img_h, img_w = img.shape
        dim = self.params.dim_window
        if img_h < dim or img_w < dim:
            # No window fits; random.randint would raise on a negative range.
            return []

        negative_patches = []
        max_attempts = 50
        for _ in range(num_negatives):
            for _attempt in range(max_attempts):
                x_min = random.randint(0, img_w - dim)
                y_min = random.randint(0, img_h - dim)
                x_max = x_min + dim
                y_max = y_min + dim

                intersects = any(
                    x_min < bx_max and x_max > bx_min and y_min < by_max and y_max > by_min
                    for (bx_min, by_min, bx_max, by_max) in bounding_boxes
                )
                if not intersects:
                    # The crop is exactly dim x dim already, so no resize is needed.
                    negative_patches.append(img[y_min:y_max, x_min:x_max])
                    break
        return negative_patches

    def _iter_training_images(self):
        """Yield (image_path, bounding_boxes) for every annotated training image.

        Walks the module-level ANNOTATIONS_TO_FOLDER map, skipping (with a
        message) annotation files and images that do not exist on disk.
        """
        for annotation_file, folder in ANNOTATIONS_TO_FOLDER.items():
            if not os.path.exists(annotation_file):
                print(f"Fișierul {annotation_file} nu există!")
                continue

            print(f"[INFO] Trecem la fișierul de adnotări: {annotation_file}, pentru folderul {folder}")
            annotations = self.load_annotations(annotation_file)

            # Keys of `annotations` are unique, so each image is visited once.
            for image_name, bounding_boxes in annotations.items():
                path = os.path.join('./antrenare', folder, image_name)
                if not os.path.exists(path):
                    print(f"[WARN] Imaginea {image_name} nu există în folderul {folder}!")
                    continue
                yield path, bounding_boxes

    def get_positive_descriptors(self):
        """Return HOG descriptors (one per row) of every annotated training box."""
        positive_descriptors = []
        for path, bounding_boxes in self._iter_training_images():
            for patch in self.generate_positive_patches(path, bounding_boxes):
                positive_descriptors.append(self._compute_hog(patch))

        print(f"[INFO] Număr total de descriptori pozitivi generați: {len(positive_descriptors)}")
        return np.array(positive_descriptors)

    def get_negative_descriptors(self):
        """Return HOG descriptors (one per row) of random non-overlapping training patches."""
        negative_descriptors = []
        for path, bounding_boxes in self._iter_training_images():
            for patch in self.generate_negative_patches(path, bounding_boxes, num_negatives=10):
                negative_descriptors.append(self._compute_hog(patch))

        print(f"[INFO] Număr total de descriptori negativi generați: {len(negative_descriptors)}")
        return np.array(negative_descriptors)

    def _detect_in_image(self, img, scales):
        """Score every sliding window of ``img`` at each scale.

        Returns (boxes, scores) for windows whose SVM score exceeds
        ``params.threshold``; boxes are [x_min, y_min, x_max, y_max] in the
        coordinates of the original (unscaled) image.
        """
        dim = self.params.dim_window
        step = max(1, dim // 4)
        original_h, original_w = img.shape
        boxes, scores = [], []

        for scale in scales:
            new_w = int(original_w * scale)
            new_h = int(original_h * scale)
            if new_w < dim or new_h < dim:
                continue

            resized_img = cv.resize(img, (new_w, new_h))
            # "+ 1" so the last window flush with the right/bottom edge is included.
            positions = [
                (x, y)
                for y in range(0, new_h - dim + 1, step)
                for x in range(0, new_w - dim + 1, step)
            ]
            descriptors = np.array([
                self._compute_hog(resized_img[y:y + dim, x:x + dim]) for (x, y) in positions
            ])
            # One batched SVM call per scale instead of one call per window.
            window_scores = self.best_model.decision_function(descriptors)

            for (x, y), score in zip(positions, window_scores):
                if score > self.params.threshold:
                    # Map the window back to original-image coordinates.
                    boxes.append([
                        int(x / scale),
                        int(y / scale),
                        int((x + dim) / scale),
                        int((y + dim) / scale),
                    ])
                    scores.append(score)
        return boxes, scores

    def run(self):
        """Detect faces in every ``*.jpg`` test image using a multi-scale sliding window.

        Windows scoring above ``params.threshold`` are kept, then non-maximum
        suppression (IoU threshold 0.5) is applied separately for each image.

        Returns:
            (boxes, scores, file_names): an (N, 4) array of
            [x_min, y_min, x_max, y_max] boxes, an (N,) array of SVM scores and
            an (N,) array with the base name of the image each box belongs to.
        """
        test_images_path = os.path.join(self.params.dir_test_examples, '*.jpg')
        test_files = glob.glob(test_images_path)

        # Image rescaling factors used to detect faces of different sizes.
        scales = [0.4, 0.6, 0.8, 1.0, 1.3, 1.6, 2.0]

        # Raw (pre-NMS) detections per image: base file name -> (boxes, scores).
        raw_detections = {}
        for test_file in test_files:
            img = cv.imread(test_file, cv.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Imaginea {test_file} nu a putut fi încărcată.")
                continue
            boxes, scores = self._detect_in_image(img, scales)
            if boxes:
                raw_detections[os.path.basename(test_file)] = (boxes, scores)

        print("Aplicăm Non-Maximum Suppression...")
        # NMS must run per image: boxes from different images share the same
        # coordinate space but must never suppress each other.
        kept_boxes, kept_scores, kept_names = [], [], []
        for file_name, (boxes, scores) in raw_detections.items():
            nms_boxes, nms_scores, _ = self.non_max_suppression(
                np.array(boxes), np.array(scores), iou_threshold=0.5)
            kept_boxes.append(nms_boxes)
            kept_scores.append(nms_scores)
            kept_names.extend([file_name] * len(nms_boxes))

        if kept_boxes:
            final_boxes = np.concatenate(kept_boxes)
            final_scores = np.concatenate(kept_scores)
        else:
            final_boxes = np.empty((0, 4), dtype=int)
            final_scores = np.empty(0)
        final_file_names = np.array(kept_names, dtype=str)

        print(f"Număr total de detecții după NMS: {len(final_boxes)}")
        return final_boxes, final_scores, final_file_names

    def eval_detections(self, detections, scores, file_names):
        """Evaluate detections against ground truth and report precision, recall and AP.

        Ground truth is read from ``params.path_annotations`` (same line format
        as load_annotations). Detections are processed in descending score
        order. A detection is a true positive when its best-overlapping
        ground-truth box in the same image has IoU >= 0.3 and has not already
        been matched. Otherwise it is a false positive. Metrics are printed and
        written to ``<dir_save_files>/evaluation_results.txt``. Nothing is
        evaluated if the annotation file is missing.
        """
        annotations_path = self.params.path_annotations
        if not os.path.exists(annotations_path):
            print(f"Debug: Fișierul de adnotări nu există la calea: {annotations_path}")
            return
        print(f"Debug: Fișierul de adnotări a fost găsit la calea: {annotations_path}")

        # 1) Ground truth: file name -> list of boxes.
        ground_truth = self.load_annotations(annotations_path)
        total_gt = sum(len(boxes) for boxes in ground_truth.values())

        # 2) All detections as (file_name, score, box), highest score first.
        all_dets = sorted(zip(file_names, scores, detections), key=lambda d: d[1], reverse=True)

        iou_threshold = 0.3
        # Which ground-truth boxes have already been claimed by a detection.
        matched_gt = {fn: [False] * len(boxes) for fn, boxes in ground_truth.items()}

        # 3) Mark each detection as TP (1) or FP (0).
        is_tp = []
        for fname, _score, box_det in all_dets:
            # Images without annotations have no GT boxes -> always FP.
            gt_boxes = ground_truth.get(fname, [])
            best_iou, best_idx = 0.0, -1
            for idx, gt_box in enumerate(gt_boxes):
                iou_val = self.compute_iou(box_det, gt_box)
                if iou_val > best_iou:
                    best_iou, best_idx = iou_val, idx
            if best_idx != -1 and best_iou >= iou_threshold and not matched_gt[fname][best_idx]:
                matched_gt[fname][best_idx] = True
                is_tp.append(1)
            else:
                is_tp.append(0)

        # 4) Precision-recall curve and AP.
        tp_array = np.array(is_tp, dtype=float)
        fp_array = 1.0 - tp_array
        tp_cum = np.cumsum(tp_array)
        fp_cum = np.cumsum(fp_array)
        # max(.., 1) avoids division by zero when there is no ground truth.
        recalls = tp_cum / float(max(total_gt, 1))
        precisions = tp_cum / (tp_cum + fp_cum + 1e-10)
        ap = self.voc_ap(recalls, precisions)

        # 5) Totals over all detections (i.e. at the lowest kept score).
        final_tp = tp_cum[-1] if tp_cum.size else 0.0
        final_fp = fp_cum[-1] if fp_cum.size else 0.0
        final_fn = total_gt - final_tp

        precision = final_tp / float(final_tp + final_fp + 1e-10)
        recall = final_tp / float(total_gt + 1e-10)

        report = [
            f"True Positives: {int(final_tp)}",
            f"False Positives: {int(final_fp)}",
            f"False Negatives: {int(final_fn)}",
            f"Precision: {precision:.4f}",
            f"Recall: {recall:.4f}",
            f"Average Precision: {ap:.4f}",
        ]
        print("Evaluare finală:")
        for report_line in report:
            print(report_line)

        results_path = os.path.join(self.params.dir_save_files, "evaluation_results.txt")
        with open(results_path, "w") as f:
            f.write("".join(report_line + "\n" for report_line in report))
        print(f"Rezultatele evaluării au fost salvate în {results_path}.")

    def voc_ap(self, rec, prec):
        """Compute Average Precision as the area under the interpolated PR curve.

        Uses all-point interpolation. Precision is replaced by its running
        maximum from the right (the precision envelope). AP is then the sum of
        ``(r[i+1] - r[i]) * p_env[i+1]`` over the points where recall changes.
        Sentinel points (recall 0, precision 1) and (recall 1, precision 0)
        are added at the ends.
        """
        mrec = np.concatenate(([0.0], rec, [1.0]))
        mpre = np.concatenate(([1.0], prec, [0.0]))

        # Precision envelope: make precision monotonically non-increasing.
        mpre = np.maximum.accumulate(mpre[::-1])[::-1]

        # Sum rectangle areas only where recall actually increases.
        changed = np.where(mrec[1:] != mrec[:-1])[0]
        return np.sum((mrec[changed + 1] - mrec[changed]) * mpre[changed + 1])

    def non_max_suppression(self, boxes, scores, iou_threshold=0.3):
        """Greedy non-maximum suppression.

        Repeatedly keeps the highest-scoring remaining box and discards every
        remaining box whose IoU with it exceeds ``iou_threshold``.

        Returns:
            (kept_boxes, kept_scores, keep), where ``keep`` is an int index
            array into the inputs, ordered by descending score.
        """
        boxes = np.asarray(boxes)
        scores = np.asarray(scores)
        if len(boxes) == 0:
            return boxes, scores, np.array([], dtype=int)

        order = np.argsort(scores)[::-1]
        keep = []
        while order.size > 0:
            best = order[0]
            keep.append(best)
            rest = order[1:]
            ious = np.array([self.compute_iou(boxes[best], boxes[j]) for j in rest])
            # Only boxes that overlap the kept one little enough survive.
            order = rest[ious <= iou_threshold]

        keep = np.array(keep, dtype=int)
        return boxes[keep], scores[keep], keep

    def compute_iou(self, boxA, boxB):
        """Return the intersection-over-union of two [x_min, y_min, x_max, y_max] boxes.

        Box area is computed as (x_max - x_min) * (y_max - y_min). A small
        epsilon guards the division against zero-area unions.
        """
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])

        interW = max(0, xB - xA)
        interH = max(0, yB - yA)
        interArea = interW * interH

        boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
        boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
        return interArea / float(boxAArea + boxBArea - interArea + 1e-10)

# Initialize the facial detector with the module-level parameters.
facial_detector = FacialDetector(params)
