Untitled
unknown
plain_text
a year ago
1.5 kB
5
Indexable
def plot_tsne(embeddings_2d, cluster_labels, file_names, largest_cluster_label=None,
              *, save_path=None, show=True):
    """Scatter-plot 2-D t-SNE embeddings grouped by cluster and save the figure.

    Each distinct cluster label is drawn as its own scatter series, and every
    point is annotated with its file name.

    Args:
        embeddings_2d: Array-like indexable as ``embeddings_2d[rows, col]``
            (e.g. a NumPy array) whose columns 0 and 1 are plotted as the two
            t-SNE components. Row ``i`` corresponds to ``cluster_labels[i]``
            and ``file_names[i]``.
        cluster_labels: Sequence with one cluster label per row. Labels are
            shown 1-based in the legend (``label + 1``), so they are expected
            to be integers.
        file_names: Sequence of per-point annotations, one per row.
        largest_cluster_label: Label of the cluster to mark as
            "Real Samples"; every other cluster is marked "Fake Samples".
            ``None`` (the default) or ``-1`` disables the real/fake suffix and
            shows plain "Cluster N" names.
        save_path: Destination file passed to ``plt.savefig``. Defaults to
            ``static/images/tsne.png`` joined with the platform's path
            separator. Missing parent directories are created.
        show: If true (the default), call ``plt.show()`` after saving.
    """
    import os

    if save_path is None:
        # Join components instead of using a backslash literal so the path
        # resolves to a nested directory on every OS, not just Windows.
        save_path = os.path.join("static", "images", "tsne.png")

    # With the default None, no label can ever match, which would tag every
    # cluster as "Fake Samples"; treat None like the -1 "no split" sentinel.
    split_real_fake = (
        largest_cluster_label is not None and largest_cluster_label != -1
    )

    fig = plt.figure(figsize=(10, 8))

    # Sorted so legend order and colour assignment are deterministic.
    unique_labels = sorted(set(cluster_labels))
    for label in unique_labels:
        indices = [i for i, lbl in enumerate(cluster_labels) if lbl == label]
        legend_text = f'Cluster {label + 1}'
        if split_real_fake:
            kind = 'Real Samples' if label == largest_cluster_label else 'Fake Samples'
            legend_text = f'{legend_text}, {kind}'
        plt.scatter(
            embeddings_2d[indices, 0],
            embeddings_2d[indices, 1],
            label=legend_text,
        )

    for i, file_name in enumerate(file_names):
        plt.annotate(file_name, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    # Only build a legend when at least one labelled series was drawn.
    if unique_labels:
        plt.legend()
    plt.grid()
    plt.title("t-SNE Visualization of Features")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # savefig does not create directories; make sure the target exists.
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)  # Save the figure

    if show:
        plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
Editor is loading...
Leave a Comment