Untitled

 avatar
unknown
plain_text
5 months ago
2.8 kB
8
Indexable
def load_and_evaluate(model_path, data):
    """Load a saved Word2Vec model and score it on an analogy question set.

    Each entry of ``data["Question"]`` must contain exactly four
    whitespace-separated words ``a b c d``. The question is lower-cased,
    the model is asked for the single nearest neighbour of ``b + c - a``,
    and the question counts as correct when that neighbour equals ``d``.

    Args:
        model_path: Path handed to ``Word2Vec.load``.
        data: DataFrame with ``"Question"``, ``"Category"`` and
            ``"SubCategory"`` columns.

    Returns:
        ``{"Category": {name: pct}, "SubCategory": {name: pct}}``, where
        ``pct`` is the percentage (0-100) of that group's questions answered
        correctly. Groups keep their order of first appearance in ``data``.
        Rows whose group value is missing (NaN) are not reported as a group.
    """
    model = Word2Vec.load(model_path)

    correct = []
    for analogy in tqdm(data["Question"]):
        word_a, word_b, word_c, word_d = analogy.lower().split()
        try:
            best = model.wv.most_similar(
                positive=[word_b, word_c], negative=[word_a], topn=1
            )[0][0]
        except KeyError:
            # most_similar raises KeyError when a query word is out of
            # vocabulary; score the question as wrong instead of aborting
            # the whole evaluation.
            best = None
        correct.append(best == word_d)

    # One correctness score per question (1.0 / 0.0), positionally aligned
    # with the rows of `data`.
    hits = pd.Series(correct, dtype=float)
    accuracies = {}
    for column in ("Category", "SubCategory"):
        # Group by the raw column values (positional, no index alignment);
        # sort=False keeps first-appearance order, matching `.unique()`.
        per_group = hits.groupby(data[column].to_numpy(), sort=False).mean()
        accuracies[column] = (per_group * 100).to_dict()
    return accuracies


# Checkpoints were saved every 5 epochs, from epoch 4 through epoch 29;
# each entry pairs the checkpoint file with its epoch number.
model_list = [(f"word2vec_{epoch}.model", epoch) for epoch in range(4, 30, 5)]
data = pd.read_csv(QUESTION_CSV_FILE)

# Per-epoch accuracy tables: {epoch: {group name: accuracy %}}.
category_accuracies = {}
sub_category_accuracies = {}
for model_path, epoch in model_list:
    accuracies = load_and_evaluate(model_path, data)
    category_accuracies[epoch], sub_category_accuracies[epoch] = (
        accuracies['Category'],
        accuracies['SubCategory'],
    )

def _plot_accuracy_panel(position, accuracies_by_epoch, groups, colormap, title):
    """Draw one subplot of per-group accuracy curves across epochs.

    Args:
        position: Subplot index within the 1x2 grid (1 = left, 2 = right).
        accuracies_by_epoch: ``{epoch: {group: accuracy_pct}}``.
        groups: Group labels to draw, one line per label.
        colormap: Matplotlib colormap sampled evenly to colour the lines.
        title: Subplot title.
    """
    plt.subplot(1, 2, position)
    epochs = sorted(accuracies_by_epoch)
    colors = colormap(np.linspace(0, 1, len(groups)))
    for color, group in zip(colors, groups):
        # A group with no score at some epoch is plotted as NaN (a gap in
        # the line) instead of 0, which would look like a real 0% result.
        values = [accuracies_by_epoch[epoch].get(group, np.nan) for epoch in epochs]
        plt.plot(epochs, values, label=group, color=color)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)


plt.figure(figsize=(14, 7))
_plot_accuracy_panel(1, category_accuracies, data["Category"].unique(),
                     cm.viridis, 'Accuracy by Category over Epochs')
_plot_accuracy_panel(2, sub_category_accuracies, data["SubCategory"].unique(),
                     cm.plasma, 'Accuracy by Sub-Category over Epochs')
plt.tight_layout()
plt.show()
Editor is loading...
Leave a Comment