face_eval.py(quoc14)
unknown
python
5 months ago
12 kB
2
Indexable
"""Open-set face identification evaluation over pre-computed embeddings.

Builds a "facebank" (one mean embedding per enrolled user), matches probe
embeddings against it with either cosine similarity or Euclidean distance,
and reports FPIR / FNIR / accuracy over a sweep of thresholds.
"""

# NOTE: the import order below is kept exactly as in the original file.
# The wildcard imports come *before* `tqdm`, `glob` and `json` on purpose:
# if `helper` or `sev` export names that collide with those, the explicit
# imports that follow must win.
import argparse
from pathlib import Path
from config import get_config
import numpy as np
import matplotlib.pyplot as plt
import os.path
import timeit
from helper import *
from sev import *
from tqdm import tqdm
import glob
import json

# Label stored at index 0 of ``FaceEval.names``; a search result index of -1
# ("no match") maps onto it through the ``index + 1`` lookup.
UNKNOWN_NAME = 'Unknown'


def _user_id_from_path(path):
    """Return the last path component, accepting both '\\' and '/' separators."""
    return path.split('\\')[-1].split('/')[-1]


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _format_floats(values):
    """Format an iterable of numbers as a ', '-separated list with 5 decimals."""
    return ', '.join('%.5f' % value for value in values)


class FaceEval:
    """Evaluate face identification against an enrolled facebank.

    Attributes:
        conf: configuration object; this class reads ``vts_enroll``,
            ``vts_inout``, ``distance``, ``method`` and ``faceText_path``.
        number_image_use: maximum number of enrollment embeddings averaged
            per user.
        embeddings: array with one mean embedding per enrolled user.
        names: array of labels; index 0 is the unknown label, index ``i + 1``
            is the user id of ``embeddings[i]``.
        known_set_ids: list of enrolled user ids (keys of the enrollment map).
        error: human-readable misclassification records, filled only when a
            single threshold is evaluated.
    """

    def __init__(self, conf, is_update=True, number_image_use=5):
        """Load the facebank.

        Args:
            conf: configuration object (see class docstring).
            is_update: accepted for backward compatibility; currently unused,
                the facebank is always rebuilt.
            number_image_use: max enrollment embeddings averaged per user.
        """
        self.conf = conf
        self.number_image_use = number_image_use
        self.embeddings, self.names = self.update_face_data()
        print('facebank updated')
        self.error = []

    def update_face_data(self):
        """Build the facebank from the enrollment folders.

        Reads ``VTS/mapUserID_enroll.json`` (user id -> enrollment folder
        name), loads up to ``number_image_use`` ``*.emb`` files per user and
        averages them. Users without any embedding file are skipped.

        Returns:
            (embeddings, names): arrays as described in the class docstring.
        """
        with open('VTS/mapUserID_enroll.json') as json_file:
            self.mapUserID_enroll = json.load(json_file)
        # Set here (not only in evaluate_K2) so evaluate_U can run on its own.
        self.known_set_ids = list(self.mapUserID_enroll)

        embeddings = []
        names = [UNKNOWN_NAME]
        for user_id, enroll_dir in self.mapUserID_enroll.items():
            # Sorted so that "the first N images" is reproducible; plain
            # glob() order depends on the filesystem.
            emb_paths = sorted(
                glob.glob(self.conf.vts_enroll + '/' + enroll_dir + '/*.emb'))
            embs = []
            for emb_path in emb_paths:
                if os.path.isfile(emb_path):
                    # load_embedding comes from one of the wildcard imports
                    # (presumably helper or sev -- verify).
                    embs.append(load_embedding(emb_path))
                if len(embs) == self.number_image_use:
                    break
            if embs:
                embeddings.append(np.array(embs).mean(0))
                names.append(user_id)
        return np.array(embeddings), np.array(names)

    def _bank_and_query(self, emb):
        """Return the facebank as (N, D) and the probe as (D,), or (None, None).

        Flattening makes the searches insensitive to whether embeddings are
        stored as (D,) or (1, D) -- the loader's output shape is not visible
        here. ``(None, None)`` signals an empty facebank.
        """
        bank = np.asarray(self.embeddings)
        if bank.size == 0:
            return None, None
        return bank.reshape(len(bank), -1), np.asarray(emb).reshape(-1)

    def search_euclid(self, emb, threshold_list):
        """Find the nearest enrolled identity by Euclidean distance.

        Args:
            emb (array): embedding of a probe image.
            threshold_list (list): distance thresholds to evaluate.

        Returns:
            dict: ``threshold -> [index, distance]`` where ``index`` is the
            facebank index of the nearest identity if ``distance <=
            threshold``, else -1. ``distance`` is always the nearest distance
            (``inf`` when the facebank is empty).
        """
        bank, query = self._bank_and_query(emb)
        if bank is None:
            return {thres: [-1, float('inf')] for thres in threshold_list}

        # Squared distances are enough to find the minimum; take a single
        # square root afterwards instead of a second norm computation.
        sq_dist = np.sum((bank - query) ** 2, axis=-1)
        index_min = int(np.argmin(sq_dist))
        dist_min = float(np.sqrt(sq_dist[index_min]))

        search_result = {}
        for thres in threshold_list:
            idx_min = index_min if dist_min <= thres else -1
            search_result[thres] = [idx_min, dist_min]
        return search_result

    def search_cosine(self, emb, threshold_list):
        """Find the most similar enrolled identity by dot-product similarity.

        The score is a plain dot product; it equals cosine similarity only if
        the embeddings are L2-normalised (assumed, not checked here).

        Args:
            emb (array): embedding of a probe image.
            threshold_list (list): similarity thresholds to evaluate.

        Returns:
            dict: ``threshold -> [index, similarity]`` where ``index`` is the
            facebank index of the best match if ``similarity >= threshold``,
            else -1. ``similarity`` is always the best score (``-inf`` when
            the facebank is empty).
        """
        bank, query = self._bank_and_query(emb)
        if bank is None:
            return {thres: [-1, float('-inf')] for thres in threshold_list}

        similarity = bank @ query
        index_max = int(np.argmax(similarity))
        # Report the best score itself. Indexing with the (possibly -1)
        # accepted index would return the last identity's score on rejection.
        sim_max = float(similarity[index_max])

        search_result = {}
        for thres in threshold_list:
            idx_max = index_max if sim_max >= thres else -1
            search_result[thres] = [idx_max, sim_max]
        return search_result

    def _search(self, emb, threshold_list):
        """Dispatch to the search selected by ``conf.distance``."""
        if self.conf.distance == 'cosine':
            return self.search_cosine(emb, threshold_list)
        return self.search_euclid(emb, threshold_list)

    def evaluate_K2(self, threshold_list):
        """Evaluate probes of enrolled (known) users.

        For every enrolled user id, each ``*.emb`` under
        ``conf.vts_inout/<user id>/`` is searched in the facebank.

        Args:
            threshold_list: thresholds to evaluate.

        Returns:
            (id_total, k2_result): number of probes, and per-threshold
            counters ``id_correct`` / ``id_wrong`` / ``id_unknown``.
        """
        id_total = 0
        k2_result = {
            thre: {"id_correct": 0, "id_wrong": 0, "id_unknown": 0}
            for thre in threshold_list
        }
        self.known_set_ids = list(self.mapUserID_enroll)
        # Per-probe error records are only kept for a single-threshold run.
        log_errors = len(threshold_list) == 1
        prefix_len = len(self.conf.vts_inout)

        for x in tqdm(self.known_set_ids):
            link_embs = glob.glob(self.conf.vts_inout + '/' + x + '/*.emb')
            for emb_path in link_embs:
                emb = load_embedding(emb_path)
                results = self._search(emb, threshold_list)
                id_total += 1
                for thre in threshold_list:
                    idx, dist = results[thre]
                    if idx < 0:
                        k2_result[thre]["id_unknown"] += 1
                        if log_errors:
                            self.error.append(
                                '--unknow: ' + x + '--dis: ' + str(dist)
                                + '--path: ' + emb_path[prefix_len:])
                    elif self.names[idx + 1] == x:
                        k2_result[thre]["id_correct"] += 1
                    else:
                        k2_result[thre]["id_wrong"] += 1
                        if log_errors:
                            self.error.append(
                                '--id_wrong: ' + self.names[idx + 1]
                                + ' --label: ' + x + ' --dis:' + str(dist)
                                + ' --path: ' + emb_path[prefix_len:])
        return id_total, k2_result

    def evaluate_U(self, threshold_list):
        """Evaluate probes of users that are NOT enrolled (unknown set).

        Every folder under ``conf.vts_inout`` whose name is not an enrolled
        user id is treated as an unknown user; a probe is correct when the
        search rejects it (index -1).

        Args:
            threshold_list: thresholds to evaluate.

        Returns:
            (id_total, u_result): number of probes, and per-threshold
            counters ``id_correct`` / ``id_wrong``.
        """
        id_total = 0
        u_result = {
            thre: {"id_correct": 0, "id_wrong": 0} for thre in threshold_list
        }
        known_ids = set(self.known_set_ids)  # O(1) membership per folder
        log_errors = len(threshold_list) == 1
        prefix_len = len(self.conf.vts_inout)

        user_inout = glob.glob(self.conf.vts_inout + '/*')
        for user in tqdm(user_inout):
            user_id = _user_id_from_path(user)
            if user_id in known_ids:
                continue
            link_embs = glob.glob(user + '/*.emb')
            for emb_path in link_embs:
                emb = load_embedding(emb_path)
                results = self._search(emb, threshold_list)
                id_total += 1
                for thre in threshold_list:
                    idx, dist = results[thre]
                    if idx < 0:
                        u_result[thre]["id_correct"] += 1
                    else:
                        u_result[thre]["id_wrong"] += 1
                        if log_errors:
                            self.error.append(
                                '--evaluate_U_wrong: ' + self.names[idx + 1]
                                + '--label: ' + user_id
                                + '--dis: ' + str(dist)
                                + '--path: ' + emb_path[prefix_len:])
        return id_total, u_result

    def evaluate_vts(self, threshold_list):
        """Evaluate every user folder under ``conf.vts_inout`` in one pass.

        Args:
            threshold_list: thresholds to evaluate.

        Returns:
            (id_total, result): number of probes, and per-threshold counters
            ``id_correct`` / ``id_wrong`` / ``id_unknown``.
        """
        id_total = 0
        result = {
            thre: {"id_correct": 0, "id_wrong": 0, "id_unknown": 0}
            for thre in threshold_list
        }
        user_inout = glob.glob(self.conf.vts_inout + '/*')
        for x in tqdm(user_inout):
            # The ground-truth label is the folder name, not the first
            # character of the path.
            user_id = _user_id_from_path(x)
            list_embs = glob.glob(x + '/*.emb')
            for emb_path in list_embs:
                emb = load_embedding(emb_path)
                results = self._search(emb, threshold_list)
                id_total += 1
                for thre in threshold_list:
                    idx = results[thre][0]
                    if idx < 0:
                        result[thre]["id_unknown"] += 1
                    elif self.names[idx + 1] == user_id:
                        result[thre]["id_correct"] += 1
                    else:
                        result[thre]["id_wrong"] += 1
        return id_total, result

    def draw(self, FPIR, FNIR, ACC, THRESHOLD, title, name_save_img):
        """Plot FPIR, FNIR (left axis) and ACC (right axis) against threshold.

        Args:
            FPIR, FNIR, ACC: sequences of rates, one entry per threshold.
            THRESHOLD: sequence of thresholds (x axis).
            title: figure title.
            name_save_img: output path passed to ``plt.savefig``.
        """
        rect = 0.1, 0.1, 0.8, 0.8
        fig = plt.figure()

        ax1 = fig.add_axes(rect)
        plt.title(title)
        plt.yticks(np.array(FPIR))
        line_FPIR = ax1.plot(THRESHOLD, FPIR, 'b-', label='FPIR')
        ax1.set_xlabel('threshold', color='b')
        line_FNIR = ax1.plot(THRESHOLD, FNIR, 'r-', label='FNIR')
        ax1.set_ylabel('FPIR(blue) && FNIR(red)')

        # Second, frameless axes over the same rectangle, with its own scale
        # drawn on the right-hand side.
        ax2 = fig.add_axes(rect, frameon=False)
        ax2.yaxis.tick_right()
        ax2.yaxis.set_label_position('right')
        line_ACCC = ax2.plot(THRESHOLD, ACC, 'y-', label="ACC")
        ax2.set_ylabel('ACC', color='y')

        lns = line_FPIR + line_FNIR + line_ACCC
        labs = [l.get_label() for l in lns]
        plt.legend(lns, labs, loc=0)
        plt.savefig(name_save_img)
        plt.close(fig)  # release the figure; repeated calls would pile up

    def open_set_evaluation(self, threshold_list):
        """Run the known-set and unknown-set evaluations and write reports.

        With a single threshold, the per-probe error records are written to
        ``open_set_results_<method>_error.txt``. With several thresholds, the
        per-threshold counts and FPIR / FNIR / ACC are printed, written to
        ``open_set_results_<method>.txt`` and plotted. Rates whose probe set
        is empty are reported as 0.0 instead of raising ZeroDivisionError.

        Args:
            threshold_list: thresholds to evaluate.
        """
        # Start from a clean slate so repeated calls don't duplicate records.
        self.error = []
        id_total1, k2_result = self.evaluate_K2(threshold_list)
        id_total2, u_result = self.evaluate_U(threshold_list)
        base_path = (str(self.conf.faceText_path) + '/open_set_results_'
                     + self.conf.method)

        if len(threshold_list) == 1:
            with open(base_path + '_error.txt', "w") as f:
                for err in self.error:
                    f.write(err + '\n')
            return

        fpir_list = []
        fnir_list = []
        acc_list = []
        with open(base_path + '.txt', "w") as f:
            for thre in threshold_list:
                known = k2_result[thre]
                unknown = u_result[thre]
                f.write('Threshold: {:.3f}\n'.format(thre))
                f.write('Total1 - Correct1 - Unknown1 - Wrong1 : {:n} - {:n} - {:n} - {:n}\n'.format(
                    id_total1, known["id_correct"], known["id_unknown"], known["id_wrong"]))
                f.write('Total2 - Correct2 - Wrong2 : {:n} - {:n} - {:n}\n'.format(
                    id_total2, unknown["id_correct"], unknown["id_wrong"]))

                FNIR = _safe_ratio(
                    known["id_unknown"] + known["id_wrong"], id_total1)
                FPIR = _safe_ratio(unknown["id_wrong"], id_total2)
                ACC = _safe_ratio(
                    known["id_correct"] + unknown["id_correct"],
                    id_total1 + id_total2)

                print('Threshold: ', thre)
                print('FPIR = ', FPIR)
                print('FNIR = ', FNIR)
                print('ACC = ', ACC)
                print("-" * 30)

                f.write('FPIR: ' + '{:.5f}\n'.format(FPIR))
                f.write('FNIR: ' + '{:.5f}\n'.format(FNIR))
                f.write('ACC: ' + '{:.5f}\n'.format(ACC))
                f.write("-" * 30 + '\n')

                acc_list.append(ACC)
                fpir_list.append(FPIR)
                fnir_list.append(FNIR)

            f.write('\nThreshold list:\t' + _format_floats(threshold_list))
            f.write('\nAccuracy list:\t' + _format_floats(acc_list))
            f.write('\nFPIR list:\t\t' + _format_floats(fpir_list))
            f.write('\nFNIR list:\t\t' + _format_floats(fnir_list))

        self.draw(fpir_list, fnir_list, acc_list, threshold_list,
                  self.conf.method + '_' + self.conf.distance, base_path)


def main():
    """Parse the command line, run the threshold sweep and print the runtime."""
    start = timeit.default_timer()

    parser = argparse.ArgumentParser(description='for face verification')
    # NOTE(review): --threshold is parsed but never used; the sweep below is
    # hard-coded.
    parser.add_argument('-th', '--threshold',
                        help='threshold to decide identical faces',
                        default=0.3, type=float)
    # NOTE(review): default=True combined with store_true means this flag is
    # always True. Kept as-is because FaceEval ignores is_update anyway.
    parser.add_argument("-u", "--update",
                        help="whether perform update the facebank",
                        default=True, action="store_true")
    args = parser.parse_args()

    conf = get_config(False)
    faceEval = FaceEval(conf, args.update, conf.number_image_use)

    threshold_list = np.arange(0.5, 1.3, 0.025).tolist()
    faceEval.open_set_evaluation(threshold_list)

    print('Time: ', timeit.default_timer() - start)


if __name__ == '__main__':
    main()
Editor is loading...
Leave a Comment