face_eval.py(quoc14)

 avatar
unknown
python
5 months ago
12 kB
2
Indexable
import argparse
from pathlib import Path
from config import get_config
import numpy as np
import matplotlib.pyplot as plt
import os.path
import timeit
from helper import *
from sev import *
from tqdm import tqdm
import glob
import json

class FaceEval:
    """Open-set face identification evaluator over pre-computed embeddings.

    The gallery ("facebank") is built from enrolment embeddings listed in
    ``VTS/mapUserID_enroll.json``; probe embeddings are read from
    ``conf.vts_inout/<user_id>/*.emb``. Matching uses the dot-product search
    (``conf.distance == 'cosine'``) or Euclidean distance (any other value).

    Attributes:
        conf             : configuration object; this class reads ``vts_enroll``,
                           ``vts_inout``, ``distance``, ``faceText_path``, ``method``.
        number_image_use : max number of enrolment embeddings averaged per user.
        mapUserID_enroll : dict loaded from ``VTS/mapUserID_enroll.json``.
        embeddings       : gallery array, one averaged embedding per enrolled user.
        names            : ``names[0]`` is the "no match" placeholder and
                           ``names[i + 1]`` is the user id of ``embeddings[i]``.
        error            : mismatch log, filled only when exactly one threshold
                           is evaluated.
    """

    def __init__(self, conf, is_update=True, number_image_use=5):
        """Store the configuration and build the facebank immediately.

        Args:
            conf             : configuration object (see class docstring).
            is_update        : accepted for backward compatibility; not read —
                               the facebank is always rebuilt.
            number_image_use : max enrolment embeddings averaged per user.
        """
        self.conf = conf
        self.number_image_use = number_image_use

        self.embeddings, self.names = self.update_face_data()
        print('facebank updated')
        self.error = []

    def update_face_data(self):
        """Load the enrolment map and build one averaged embedding per user.

        Reads ``VTS/mapUserID_enroll.json`` (relative to the working directory),
        whose values name a sub-directory of ``conf.vts_enroll``. For each entry,
        up to ``number_image_use`` ``*.emb`` files are loaded and averaged; users
        with no embedding files are skipped.

        Returns:
            embeddings (np.ndarray) : stacked per-user mean embeddings.
            names      (np.ndarray) : ``['Unknow', <user ids aligned with embeddings>]``.
        """
        with open('VTS/mapUserID_enroll.json', encoding='utf-8') as json_file:
            self.mapUserID_enroll = json.load(json_file)
        embeddings = []
        # Slot 0 stands for "no match": searches return index -1 for unknowns,
        # so names[idx + 1] resolves matches and non-matches uniformly.
        names = ['Unknow']
        for user_id, enroll_dir in self.mapUserID_enroll.items():
            embs = []
            # Sorted so the subset picked by number_image_use is reproducible
            # (glob order is filesystem-dependent).
            for emb_path in sorted(glob.glob(self.conf.vts_enroll + '/' + enroll_dir + '/*.emb')):
                if os.path.isfile(emb_path):
                    embs.append(load_embedding(emb_path))
                if len(embs) == self.number_image_use:
                    break

            if embs:
                embeddings.append(np.array(embs).mean(0))
                names.append(user_id)

        return np.array(embeddings), np.array(names)

    def _flat_gallery(self):
        """Return the (non-empty) gallery as a 2-D ``(num_users, dim)`` array."""
        return self.embeddings.reshape(len(self.embeddings), -1)

    def _search(self, emb, threshold_list):
        """Dispatch to the search selected by ``conf.distance``."""
        if self.conf.distance == 'cosine':
            return self.search_cosine(emb, threshold_list)
        return self.search_euclid(emb, threshold_list)

    @staticmethod
    def _dir_name(path):
        """Return the last component of a path using either '/' or '\\' separators."""
        return path.split('\\')[-1].split('/')[-1]

    @staticmethod
    def _ratio(num, den):
        """Return ``num / den``, or NaN when ``den`` is zero."""
        return num / den if den else float('nan')

    def search_euclid(self, emb, threshold_list):
        """
            Find the nearest gallery embedding (Euclidean distance) for several thresholds at once.

            Args:
                emb             (array)         : probe embedding; flattened before use.
                threshold_list  (list)          : list of distance thresholds.

            Returns:
                search_result   (dictionary)    : threshold -> [index, distance]. ``index`` is the
                                                  nearest gallery row, or -1 if its distance exceeds
                                                  the threshold (or the gallery is empty); ``distance``
                                                  is always the nearest-neighbour distance.
        """
        if len(self.embeddings) == 0:
            return {thres: [-1, float('nan')] for thres in threshold_list}
        query = np.asarray(emb).reshape(-1)
        sq_dist = np.sum((self._flat_gallery() - query) ** 2, axis=-1)  # faster than np.linalg.norm()
        index_min = int(np.argmin(sq_dist))
        # Reuse the squared distance already computed for the winner.
        dist_min = float(np.sqrt(sq_dist[index_min]))
        return {thres: [index_min if dist_min <= thres else -1, dist_min]
                for thres in threshold_list}

    def search_cosine(self, emb, threshold_list):
        """
            Find the most similar gallery embedding (dot product) for several thresholds at once.

            The score is a raw dot product; it equals cosine similarity only if the
            vectors are L2-normalised. Gallery rows are means of enrolment embeddings and
            are not renormalised here. NOTE(review): confirm load_embedding normalises.

            Args:
                emb             (array)         : probe embedding; flattened before use.
                threshold_list  (list)          : list of similarity thresholds.

            Returns:
                search_result   (dictionary)    : threshold -> [index, score]. ``index`` is the best
                                                  gallery row, or -1 if its score is below the
                                                  threshold (or the gallery is empty); ``score`` is
                                                  always the best score.
        """
        if len(self.embeddings) == 0:
            return {thres: [-1, float('nan')] for thres in threshold_list}
        query = np.asarray(emb).reshape(-1)
        scores = self._flat_gallery() @ query
        index_max = int(np.argmax(scores))
        # Report the best score even when rejected (the old code indexed scores[-1] here).
        best = float(scores[index_max])
        return {thres: [index_max if best >= thres else -1, best]
                for thres in threshold_list}

    def evaluate_K2(self, threshold_list):
        """Evaluate probes of enrolled (known) users.

        Probes are read from ``conf.vts_inout/<user_id>/*.emb`` for every key of
        ``mapUserID_enroll``. When exactly one threshold is given, unknown and wrong
        decisions are appended to ``self.error``.

        Returns:
            id_total  (int)  : number of probe embeddings processed.
            k2_result (dict) : threshold -> {"id_correct", "id_wrong", "id_unknown"} counts.
        """
        id_total = 0
        k2_result = {thre: {"id_correct": 0, "id_wrong": 0, "id_unknown": 0}
                     for thre in threshold_list}
        # Kept as a public attribute for callers that read it after this call.
        self.known_set_ids = list(self.mapUserID_enroll)
        log_errors = len(threshold_list) == 1

        for x in tqdm(self.known_set_ids):
            for emb_path in glob.glob(self.conf.vts_inout + '/' + x + '/*.emb'):
                results = self._search(load_embedding(emb_path), threshold_list)
                id_total += 1
                rel_path = emb_path[len(self.conf.vts_inout):]
                for thre in threshold_list:
                    idx, dist = results[thre]
                    if idx < 0:
                        k2_result[thre]["id_unknown"] += 1
                        if log_errors:
                            self.error.append('--unknow: ' + x +
                                              '--dis: ' + str(dist) +
                                              '--path: ' + rel_path)
                    elif self.names[idx + 1] == x:
                        k2_result[thre]["id_correct"] += 1
                    else:
                        k2_result[thre]["id_wrong"] += 1
                        if log_errors:
                            self.error.append('--id_wrong: ' + self.names[idx + 1] +
                                              ' --label: ' + x +
                                              ' --dis:' + str(dist) +
                                              ' --path: ' + rel_path)
        return id_total, k2_result

    def evaluate_U(self, threshold_list):
        """Evaluate probes of users that are NOT enrolled (the unknown set).

        Every sub-directory of ``conf.vts_inout`` whose name is not a key of
        ``mapUserID_enroll`` is treated as an unknown user; rejecting its probes is
        correct. When exactly one threshold is given, false accepts are appended to
        ``self.error``.

        Returns:
            id_total (int)  : number of probe embeddings processed.
            u_result (dict) : threshold -> {"id_correct", "id_wrong"} counts.
        """
        id_total = 0
        u_result = {thre: {"id_correct": 0, "id_wrong": 0} for thre in threshold_list}
        # Derived locally so this method does not depend on evaluate_K2 having run.
        known_ids = set(self.mapUserID_enroll)
        log_errors = len(threshold_list) == 1

        for user in tqdm(glob.glob(self.conf.vts_inout + '/*')):
            user_id = self._dir_name(user)
            if user_id in known_ids:
                continue
            for emb_path in glob.glob(user + '/*.emb'):
                results = self._search(load_embedding(emb_path), threshold_list)
                id_total += 1
                for thre in threshold_list:
                    idx, dist = results[thre]
                    if idx < 0:
                        u_result[thre]["id_correct"] += 1
                    else:
                        u_result[thre]["id_wrong"] += 1
                        if log_errors:
                            self.error.append('--evaluate_U_wrong: ' + self.names[idx + 1] +
                                              '--label: ' + user_id +
                                              '--dis: ' + str(dist) +
                                              '--path: ' + emb_path[len(self.conf.vts_inout):])
        return id_total, u_result

    def evaluate_vts(self, threshold_list):
        """Closed-style evaluation over every user directory in ``conf.vts_inout``.

        Returns:
            id_total (int)  : number of probe embeddings processed.
            result   (dict) : threshold -> {"id_correct", "id_wrong", "id_unknown"} counts.
        """
        id_total = 0
        result = {thre: {"id_correct": 0, "id_wrong": 0, "id_unknown": 0}
                  for thre in threshold_list}

        for user_dir in tqdm(glob.glob(self.conf.vts_inout + '/*')):
            # Compare against the directory name, not the first path character.
            user_id = self._dir_name(user_dir)
            for emb_path in glob.glob(user_dir + '/*.emb'):
                results = self._search(load_embedding(emb_path), threshold_list)
                id_total += 1
                for thre in threshold_list:
                    idx = results[thre][0]
                    # -1 marks a rejected probe; checking the index avoids relying
                    # on the placeholder name spelling.
                    if idx < 0:
                        result[thre]["id_unknown"] += 1
                    elif self.names[idx + 1] == user_id:
                        result[thre]["id_correct"] += 1
                    else:
                        result[thre]["id_wrong"] += 1

        return id_total, result

    def draw(self, FPIR, FNIR, ACC, THRESHOLD, title, name_save_img):
        """Plot FPIR/FNIR (left axis) and ACC (right axis) against threshold and save it.

        Args:
            FPIR, FNIR, ACC : per-threshold metric lists aligned with THRESHOLD.
            THRESHOLD       : threshold values (x axis).
            title           : figure title.
            name_save_img   : output path passed to ``plt.savefig``.
        """
        rect = 0.1, 0.1, 0.8, 0.8
        fig = plt.figure()
        ax1 = fig.add_axes(rect)
        plt.title(title)
        # NOTE(review): puts a y tick at every FPIR value; may crowd long sweeps.
        plt.yticks(np.array(FPIR))
        line_FPIR = ax1.plot(THRESHOLD, FPIR, 'b-', label='FPIR')
        ax1.set_xlabel('threshold', color='b')

        line_FNIR = ax1.plot(THRESHOLD, FNIR, 'r-', label='FNIR')
        ax1.set_ylabel('FPIR(blue) && FNIR(red)')
        # A second, frameless axes on the same rectangle gives ACC its own right-hand scale.
        ax2 = fig.add_axes(rect, frameon=False)
        ax2.yaxis.tick_right()
        ax2.yaxis.set_label_position('right')

        line_ACC = ax2.plot(THRESHOLD, ACC, 'y-', label="ACC")
        ax2.set_ylabel('ACC', color='y')
        lns = line_FPIR + line_FNIR + line_ACC
        labs = [l.get_label() for l in lns]
        plt.legend(lns, labs, loc=0)
        plt.savefig(name_save_img)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    def open_set_evaluation(self, threshold_list):
        """Run the known (K2) and unknown (U) evaluations and write the results.

        Outputs go under ``conf.faceText_path`` with prefix
        ``open_set_results_<conf.method>``:
          * exactly one threshold: only ``<prefix>_error.txt`` (the ``self.error`` log);
          * otherwise: ``<prefix>.txt`` per-threshold report plus a plot saved at ``<prefix>``.

        Per threshold:
            FNIR = (unknown + wrong among known probes) / known probes
            FPIR = accepted unknown-user probes / unknown-user probes
            ACC  = (correct known + rejected unknown) / all probes
        A metric whose denominator is zero is reported as NaN instead of raising.
        """
        id_total1, k2_result = self.evaluate_K2(threshold_list)
        id_total2, u_result = self.evaluate_U(threshold_list)
        out_prefix = str(self.conf.faceText_path) + '/open_set_results_' + self.conf.method

        if len(threshold_list) == 1:
            with open(out_prefix + '_error.txt', "w") as f:
                for err in self.error:
                    f.write(err + '\n')
            return

        fpir_list = []
        fnir_list = []
        acc_list = []
        with open(out_prefix + '.txt', "w") as f:
            for thre in threshold_list:
                k2 = k2_result[thre]
                u = u_result[thre]
                f.write('Threshold: {:.3f}\n'.format(thre))
                f.write('Total1 - Correct1 - Unknown1 - Wrong1 : {:n} - {:n} - {:n} - {:n}\n'.format(id_total1, k2["id_correct"], k2["id_unknown"], k2["id_wrong"]))
                f.write('Total2 - Correct2 -  Wrong2 : {:n} - {:n} - {:n}\n'.format(id_total2, u["id_correct"], u["id_wrong"]))
                fnir = self._ratio(k2["id_unknown"] + k2["id_wrong"], id_total1)
                fpir = self._ratio(u["id_wrong"], id_total2)
                acc = self._ratio(k2["id_correct"] + u["id_correct"], id_total1 + id_total2)
                print('Threshold: ', thre)
                print('FPIR = ', fpir)
                print('FNIR = ', fnir)
                print('ACC = ', acc)
                print("-"*30)
                f.write('FPIR: ' + '{:.5f}\n'.format(fpir))
                f.write('FNIR: ' + '{:.5f}\n'.format(fnir))
                f.write('ACC: ' + '{:.5f}\n'.format(acc))
                f.write("-"*30 + '\n')
                acc_list.append(acc)
                fpir_list.append(fpir)
                fnir_list.append(fnir)
            f.write('\nThreshold list:\t' + ', '.join(['%.5f']*len(threshold_list)) % tuple(threshold_list))
            f.write('\nAccuracy list:\t' + ', '.join(['%.5f']*len(acc_list)) % tuple(acc_list))
            f.write('\nFPIR list:\t\t' + ', '.join(['%.5f']*len(fpir_list)) % tuple(fpir_list))
            f.write('\nFNIR list:\t\t' + ', '.join(['%.5f']*len(fnir_list)) % tuple(fnir_list))

        self.draw(fpir_list, fnir_list, acc_list, threshold_list, self.conf.method + '_' + self.conf.distance,
                  out_prefix)
       
if __name__ == '__main__':
    # Wall-clock timer covering argument parsing, facebank build and evaluation.
    t_begin = timeit.default_timer()

    cli = argparse.ArgumentParser(description='for face verification')
    # NOTE(review): parsed but never read below; the threshold sweep is fixed.
    cli.add_argument('-th', '--threshold',
                     help='threshold to decide identical faces',
                     default=0.3, type=float)
    # NOTE(review): default=True with action="store_true" makes this flag always True.
    cli.add_argument("-u", "--update",
                     help="whether perform update the facebank",
                     default=True, action="store_true")
    cli_args = cli.parse_args()

    config = get_config(False)

    evaluator = FaceEval(config, cli_args.update, config.number_image_use)
    # Threshold sweep 0.5, 0.525, ... (exclusive of 1.3). Passing a one-element
    # list such as np.array([0.9]) instead makes open_set_evaluation write only
    # the per-sample error log.
    sweep = np.arange(0.5, 1.3, 0.025).tolist()
    evaluator.open_set_evaluation(sweep)
    print('Time: ', timeit.default_timer() - t_begin)
Editor is loading...
Leave a Comment