Untitled
unknown
plain_text
5 months ago
2.8 kB
8
Indexable
def _group_accuracies(data, column, golds_np, preds_np):
    """Return the percentage accuracy for each distinct value of ``data[column]``.

    ``golds_np`` and ``preds_np`` must be aligned row-for-row with ``data``.
    """
    accuracies = {}
    for group in data[column].unique():
        mask = (data[column] == group).to_numpy()
        accuracies[group] = np.mean(golds_np[mask] == preds_np[mask]) * 100
    return accuracies


def load_and_evaluate(model_path, data):
    """Evaluate a saved Word2Vec model on word-analogy questions.

    Each ``Question`` holds four whitespace-separated words ``a b c d``; the
    prediction is the single nearest word to ``b + c - a`` and it is scored
    against ``d`` (all words lower-cased first).

    Args:
        model_path: path passed to ``Word2Vec.load``.
        data: DataFrame with ``Question``, ``Category`` and ``SubCategory``
            columns.

    Returns:
        ``{'Category': {name: accuracy_percent, ...},
        'SubCategory': {name: accuracy_percent, ...}}``.

    A question containing a word missing from the model vocabulary is counted
    as wrong rather than aborting the whole evaluation.
    """
    model = Word2Vec.load(model_path)
    preds = []
    golds = []
    for analogy in tqdm(data["Question"]):
        word_a, word_b, word_c, word_d = analogy.lower().split()
        try:
            pred = model.wv.most_similar(
                positive=[word_b, word_c], negative=[word_a], topn=1
            )[0][0]
        except KeyError:
            # gensim raises KeyError for an out-of-vocabulary word; one unknown
            # word must not kill the evaluation of every other question.
            pred = None
        preds.append(pred)
        golds.append(word_d)

    # Object dtype so a None prediction compares element-wise (never equal).
    golds_np = np.array(golds, dtype=object)
    preds_np = np.array(preds, dtype=object)
    return {
        'Category': _group_accuracies(data, "Category", golds_np, preds_np),
        'SubCategory': _group_accuracies(data, "SubCategory", golds_np, preds_np),
    }


def _plot_accuracy_over_epochs(accuracies_by_epoch, groups, colormap, title):
    """Draw one accuracy-vs-epoch line per group on the current axes.

    ``accuracies_by_epoch`` maps epoch -> {group: accuracy_percent}; a group
    absent from an epoch's results is plotted as 0.
    """
    epochs = sorted(accuracies_by_epoch)
    colors = colormap(np.linspace(0, 1, len(groups)))
    for color, group in zip(colors, groups):
        values = [accuracies_by_epoch[epoch].get(group, 0) for epoch in epochs]
        plt.plot(epochs, values, label=group, color=color)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)


# (saved model path, epoch number) pairs to compare.
model_list = [
    ('word2vec_4.model', 4),
    ('word2vec_9.model', 9),
    ('word2vec_14.model', 14),
    ('word2vec_19.model', 19),
    ('word2vec_24.model', 24),
    ('word2vec_29.model', 29),
]

category_accuracies = {}
sub_category_accuracies = {}
data = pd.read_csv(QUESTION_CSV_FILE)
for model_path, epoch in model_list:
    results = load_and_evaluate(model_path, data)
    category_accuracies[epoch] = results['Category']
    sub_category_accuracies[epoch] = results['SubCategory']

plt.figure(figsize=(14, 7))
plt.subplot(1, 2, 1)
_plot_accuracy_over_epochs(
    category_accuracies,
    data["Category"].unique(),
    cm.viridis,
    'Accuracy by Category over Epochs',
)
plt.subplot(1, 2, 2)
_plot_accuracy_over_epochs(
    sub_category_accuracies,
    data["SubCategory"].unique(),
    cm.plasma,
    'Accuracy by Sub-Category over Epochs',
)
plt.tight_layout()
plt.show()
Editor is loading...
Leave a Comment