Untitled
unknown
plain_text
a year ago
2.8 kB
11
Indexable
def load_and_evaluate(model_path, data):
    """Score a saved Word2Vec model on a table of word-analogy questions.

    Every entry of ``data["Question"]`` is lower-cased and split on
    whitespace into four words ``a b c d``. The prediction for ``d`` is the
    top-ranked result of ``model.wv.most_similar(positive=[b, c],
    negative=[a], topn=1)``; it counts as correct when it equals ``d``.

    Questions for which ``most_similar`` raises ``KeyError`` (gensim does
    this for words missing from the vocabulary) are scored as incorrect
    instead of aborting the whole evaluation; a single ``RuntimeWarning``
    reports how many questions were affected.

    Args:
        model_path: Path handed to ``Word2Vec.load``.
        data: Table with ``"Question"``, ``"Category"`` and
            ``"SubCategory"`` columns (the script passes the result of
            ``pd.read_csv``).

    Returns:
        ``{'Category': {...}, 'SubCategory': {...}}`` where each inner dict
        maps a label found in the corresponding column to the accuracy, in
        percent, over the questions carrying that label.

    Raises:
        ValueError: If a question does not split into exactly four words.
    """
    import warnings

    # Gold words come from str.split() and are therefore never empty, so the
    # empty string can never be mistaken for a correct prediction.
    no_prediction = ""

    vectors = Word2Vec.load(model_path).wv

    golds = []
    preds = []
    unanswerable = 0
    for analogy in tqdm(data["Question"]):
        word_a, word_b, word_c, word_d = analogy.lower().split()
        try:
            best = vectors.most_similar(
                positive=[word_b, word_c], negative=[word_a], topn=1
            )[0][0]
        except KeyError:
            # Out-of-vocabulary word: keep the question but score it as a miss.
            best = no_prediction
            unanswerable += 1
        golds.append(word_d)
        preds.append(best)

    if unanswerable:
        warnings.warn(
            f"{unanswerable} of {len(golds)} analogy questions contained "
            "out-of-vocabulary words and were scored as incorrect",
            RuntimeWarning,
            stacklevel=2,
        )

    # Compare once; every group below just selects from this boolean vector.
    is_correct = np.array(golds) == np.array(preds)

    accuracies = {'Category': {}, 'SubCategory': {}}
    # The result keys double as the names of the grouping columns in `data`.
    for column, per_label in accuracies.items():
        labels = data[column]
        for label in labels.unique():
            # Convert to a plain array so the selection is positional and
            # independent of the table's index.
            mask = np.asarray(labels == label)
            per_label[label] = np.mean(is_correct[mask]) * 100
    return accuracies
# Saved checkpoints to evaluate, paired with the epoch number each one
# is plotted under.
model_list = [
    ('word2vec_4.model', 4),
    ('word2vec_9.model', 9),
    ('word2vec_14.model', 14),
    ('word2vec_19.model', 19),
    ('word2vec_24.model', 24),
    ('word2vec_29.model', 29),
]

# The analogy questions are read once and shared by every checkpoint.
data = pd.read_csv(QUESTION_CSV_FILE)

# epoch -> {label: accuracy in percent}, one mapping per grouping level.
category_accuracies = {}
sub_category_accuracies = {}
for checkpoint_path, epoch in model_list:
    scores = load_and_evaluate(checkpoint_path, data)
    category_accuracies[epoch] = scores['Category']
    sub_category_accuracies[epoch] = scores['SubCategory']
# Draw two side-by-side panels (categories on the left, sub-categories on
# the right), one line per label, showing accuracy against epoch.
# Each spec is (title, labels in order of first appearance, per-epoch
# accuracies, colormap).
panels = [
    ('Accuracy by Category over Epochs',
     data["Category"].unique(), category_accuracies, cm.viridis),
    ('Accuracy by Sub-Category over Epochs',
     data["SubCategory"].unique(), sub_category_accuracies, cm.plasma),
]

plt.figure(figsize=(14, 7))
for panel_index, (title, labels, per_epoch, colormap) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    # Computed once per panel rather than once per plotted line.
    epochs = sorted(per_epoch)
    # One evenly spaced colour from the colormap for each label.
    line_colors = colormap(np.linspace(0, 1, len(labels)))
    for label, line_color in zip(labels, line_colors):
        # A label missing from an epoch's results is plotted as 0.
        series = [per_epoch[epoch].get(label, 0) for epoch in epochs]
        plt.plot(epochs, series, label=label, color=line_color)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()
Leave a Comment