# find_best.py
#
# Paste-site page metadata (not part of the script):
#   avatar / unknown / python / 6 months ago / 2.7 kB / 7 / Indexable
import os
import json
from collections import defaultdict, Counter

# Folder containing the prediction files.
folder = 'preds'
# os.listdir() returns entries in arbitrary order and also lists
# sub-directories. Keep regular files only (a directory cannot be opened as a
# prediction file) and sort them so the reported order is reproducible.
files = sorted(
    name for name in os.listdir(folder)
    if os.path.isfile(os.path.join(folder, name))
)

# Map every question id to its category, as listed in questions.json.
with open('questions.json', 'r', encoding='utf-8') as f:
    questions = json.load(f)['questions']
category_dict = {question['qid']: question['category'] for question in questions}

# Build {file name: {qid: predicted 'retrieve' value}} for every prediction file.
answer_sheets = {}
for file in files:
    file_path = os.path.join(folder, file)
    # Read as UTF-8 explicitly (same as questions.json) instead of relying on
    # the platform's default locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    pred_dict = {}
    # Answers are required to be numbered 1, 2, 3, ... in file order.
    for expected_qid, answer in enumerate(data['answers'], start=1):
        qid = answer['qid']
        retrieve = answer['retrieve']
        # 'retrieve' may be a bare value or a list of values; for a non-empty
        # list only the first entry is kept. An explicit type check replaces
        # the former bare `except:`, which hid unrelated errors and would also
        # have cut a string value down to its first character.
        if isinstance(retrieve, list) and retrieve:
            retrieve = retrieve[0]
        if qid != expected_qid:
            raise ValueError(f"File {file}: ID sequence error at qid {qid}")
        pred_dict[qid] = retrieve
    answer_sheets[file] = pred_dict

# Per-category count of correct answers, keyed by prediction file name.
score_dict_insurance = {file: 0 for file in files}
score_dict_finance = {file: 0 for file in files}
score_dict_faq = {file: 0 for file in files}

# Category name -> the counter dict that a correct answer is credited to.
# Categories not listed here are not counted anywhere.
score_dict_by_category = {
    'insurance': score_dict_insurance,
    'finance': score_dict_finance,
    'faq': score_dict_faq,
}

# Read as UTF-8 explicitly (same as questions.json); the handle gets its own
# name so it does not shadow `file`, which holds prediction file names.
with open("ground_truth.json", "r", encoding="utf-8") as gt_file:
    answers = json.load(gt_file)['answers']

for answer in answers:
    qid = answer['qid']
    ground_truth = answer['retrieve']
    for file in files:
        # A prediction scores only on an exact match with the ground truth.
        if answer_sheets[file][qid] != ground_truth:
            continue
        category_scores = score_dict_by_category.get(category_dict[qid])
        if category_scores is not None:
            category_scores[file] += 1

def _rank_by_score(score_dict):
    """Return the (file, score) pairs of *score_dict*, highest score first.

    The sort is stable, so files with equal scores keep the order they have
    in *score_dict*.
    """
    return sorted(score_dict.items(), key=lambda pair: pair[1], reverse=True)


def _best_file(ranked_scores):
    """Return the top-ranked file name, or None when nothing was scored."""
    return ranked_scores[0][0] if ranked_scores else None


# Rank each category separately. The "best" file must be read from the ranked
# list; previously it was read from the unsorted one, so the first file listed
# was reported as best regardless of its score.
insurance_score_list = _rank_by_score(score_dict_insurance)
finance_score_list = _rank_by_score(score_dict_finance)
faq_score_list = _rank_by_score(score_dict_faq)
# All three rankings concatenated (insurance, finance, FAQ); kept because the
# script has always exposed this name.
score_list = insurance_score_list + finance_score_list + faq_score_list

print(f'[Insurance] Score list: {insurance_score_list}')
insurance_best_file = _best_file(insurance_score_list)
print(f"[Insurance] Best prediction: {insurance_best_file}")

print('-' * 50)

print(f'[Finance] Score list: {finance_score_list}')
finance_best_file = _best_file(finance_score_list)
print(f"[Finance] Best prediction: {finance_best_file}")

print('-' * 50)

print(f'[FAQ] Score list: {faq_score_list}')
faq_best_file = _best_file(faq_score_list)
print(f"[FAQ] Best prediction: {faq_best_file}")
# Paste-site page footer (not part of the script):
#   Editor is loading... / Leave a Comment