find_best.py
unknown
python
6 months ago
2.7 kB
7
Indexable
import os
import json
from collections import defaultdict, Counter

# Rank the prediction files found in ``folder`` by the number of questions
# each one answers correctly, separately for every question category, and
# print the ranking together with the best file per category.

# Define folder containing prediction files
folder = 'preds'

# (category key as stored in questions.json, label used in the report).
# The order of this tuple is the order of the printed report sections.
CATEGORY_LABELS = (
    ('insurance', 'Insurance'),
    ('finance', 'Finance'),
    ('faq', 'FAQ'),
)


def load_json(path):
    """Return the parsed content of the UTF-8 encoded JSON file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def build_category_dict(questions):
    """Map every question's ``qid`` to its ``category``.

    *questions* is the list stored under ``"questions"`` in questions.json.
    """
    return {question['qid']: question['category'] for question in questions}


def extract_prediction(retrieve):
    """Return the single predicted value held in a ``retrieve`` field.

    A non-empty list/tuple is reduced to its first element; anything else
    (a scalar, an empty list, a string) is returned unchanged. Strings are
    deliberately kept whole rather than indexed.
    """
    if isinstance(retrieve, (list, tuple)) and retrieve:
        return retrieve[0]
    return retrieve


def parse_answer_sheet(data, name):
    """Convert one loaded prediction file into a ``{qid: prediction}`` dict.

    *data* is the parsed JSON object (it must have an ``"answers"`` list) and
    *name* is the file name, used only in the error message.

    Raises:
        ValueError: if the qids are not exactly 1, 2, 3, ... in file order.
    """
    pred_dict = {}
    for idx, answer in enumerate(data['answers']):
        qid = answer['qid']
        if qid != idx + 1:
            raise ValueError(f"File {name}: ID sequence error at qid {qid}")
        pred_dict[qid] = extract_prediction(answer['retrieve'])
    return pred_dict


def load_answer_sheets(pred_folder):
    """Load every prediction file in *pred_folder*.

    Returns a ``{file name: {qid: prediction}}`` dict. Entries are visited in
    sorted name order so that ties between equally good files are always
    broken the same way; sub-directories are skipped.
    """
    sheets = {}
    for name in sorted(os.listdir(pred_folder)):
        path = os.path.join(pred_folder, name)
        if not os.path.isfile(path):
            continue
        sheets[name] = parse_answer_sheet(load_json(path), name)
    return sheets


def score_predictions(answer_sheets, answers, category_dict):
    """Count the correct predictions of every file, per category.

    *answers* is the list stored under ``"answers"`` in ground_truth.json.
    Returns ``{category key: {file name: score}}`` with an entry (possibly 0)
    for every file in every category of ``CATEGORY_LABELS``.

    A qid whose category is unknown is ignored, and a qid missing from a
    prediction file simply earns that file no point.
    """
    scores = {
        key: {name: 0 for name in answer_sheets}
        for key, _ in CATEGORY_LABELS
    }
    for answer in answers:
        qid = answer['qid']
        category = category_dict.get(qid)
        if category not in scores:
            continue
        ground_truth = answer['retrieve']
        for name, sheet in answer_sheets.items():
            if qid in sheet and sheet[qid] == ground_truth:
                scores[category][name] += 1
    return scores


def rank_scores(score_dict):
    """Return ``(file name, score)`` pairs ordered from best to worst.

    The sort is stable, so files with equal scores keep the order they have
    in *score_dict*.
    """
    return sorted(score_dict.items(), key=lambda item: item[1], reverse=True)


def main():
    """Score all prediction files and print the per-category report."""
    questions = load_json('questions.json')['questions']
    category_dict = build_category_dict(questions)
    answer_sheets = load_answer_sheets(folder)
    answers = load_json('ground_truth.json')['answers']
    scores = score_predictions(answer_sheets, answers, category_dict)

    for position, (key, label) in enumerate(CATEGORY_LABELS):
        if position:
            print('-' * 50)
        # The best file must be read from the *sorted* list; the head of the
        # unsorted list is merely the first file in the directory listing.
        ranked = rank_scores(scores[key])
        best_file = ranked[0][0] if ranked else None
        print(f'[{label}] Score list: {ranked}')
        print(f"[{label}] Best prediction: {best_file}")


if __name__ == '__main__':
    main()
Editor is loading...
Leave a Comment