find_best.py
unknown
python
a year ago
2.7 kB
20
Indexable
import os
import json
from collections import defaultdict, Counter
# Folder holding one JSON answer sheet per candidate prediction run.
folder = 'preds'
# os.listdir returns entries in arbitrary order and also lists subdirectories and
# hidden files (e.g. .DS_Store) that json.load cannot read. Keep only visible
# regular files, and sort them so the report -- and which file wins when scores
# tie -- is reproducible across machines.
files = sorted(
    name
    for name in os.listdir(folder)
    if not name.startswith('.') and os.path.isfile(os.path.join(folder, name))
)

# Map each question id to its category (e.g. 'insurance'), read from the
# "questions" list in questions.json.
with open('questions.json', 'r', encoding='utf-8') as f:
    questions = json.load(f)['questions']
category_dict = {question['qid']: question['category'] for question in questions}
# answer_sheets[file_name][qid] -> the single retrieved id predicted for that question.
answer_sheets = {}
for file in files:
    file_path = os.path.join(folder, file)
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pred_dict = {}
    for idx, answer in enumerate(data['answers']):
        qid = answer['qid']
        retrieve = answer['retrieve']
        # Some answer sheets store 'retrieve' as a ranked list; only the top entry
        # is scored. An explicit type check (instead of a bare try/except around
        # indexing) avoids truncating a string value to its first character and
        # no longer hides unrelated errors. An empty list means "no prediction".
        if isinstance(retrieve, list):
            retrieve = retrieve[0] if retrieve else None
        # Answers must be listed in order with qids 1, 2, 3, ...
        if qid != idx + 1:
            raise ValueError(f"File {file}: ID sequence error at qid {qid}")
        pred_dict[qid] = retrieve
    answer_sheets[file] = pred_dict
# Per-category count of exactly-matching predictions, keyed by prediction file name.
score_dict_insurance = {file: 0 for file in files}
score_dict_finance = {file: 0 for file in files}
score_dict_faq = {file: 0 for file in files}

# Route each question category to its tally; questions in any other category
# (or whose qid is absent from questions.json) are not scored.
_score_dict_by_category = {
    'insurance': score_dict_insurance,
    'finance': score_dict_finance,
    'faq': score_dict_faq,
}

with open("ground_truth.json", "r", encoding="utf-8") as gt_file:
    answers = json.load(gt_file)['answers']

for answer in answers:
    qid = answer['qid']
    ground_truth = answer['retrieve']
    # The category depends only on the question, so look it up once per qid.
    tally = _score_dict_by_category.get(category_dict.get(qid))
    if tally is None:
        continue
    for file in files:
        # A prediction file missing this qid counts as wrong instead of
        # aborting the whole run with a KeyError.
        if answer_sheets[file].get(qid) == ground_truth:
            tally[file] += 1
def _rank(scores):
    """Return ``(file, score)`` pairs sorted by descending score.

    ``sorted`` is stable even with ``reverse=True``, so files with equal scores
    keep their original (insertion) order.
    """
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Ranked lists per category. Previously the printed lists and the "best" file
# were taken from the UNSORTED lists, so the reported winner was merely the
# first file listed rather than the top scorer.
insurance_score_list = _rank(score_dict_insurance)
finance_score_list = _rank(score_dict_finance)
faq_score_list = _rank(score_dict_faq)

# The best file heads each ranked list; None when there were no prediction files.
insurance_best_file = insurance_score_list[0][0] if insurance_score_list else None
finance_best_file = finance_score_list[0][0] if finance_score_list else None
faq_best_file = faq_score_list[0][0] if faq_score_list else None

print(f'[Insurance] Score list: {insurance_score_list}')
print(f"[Insurance] Best prediction: {insurance_best_file}")
print('-' * 50)
print(f'[Finance] Score list: {finance_score_list}')
print(f"[Finance] Best prediction: {finance_best_file}")
print('-' * 50)
print(f'[FAQ] Score list: {faq_score_list}')
print(f"[FAQ] Best prediction: {faq_best_file}")
Leave a Comment