Untitled
unknown
plain_text
a year ago
1.5 kB
10
Indexable
def plot_tsne(embeddings_2d, cluster_labels, file_names, largest_cluster_label=None,
              *, save_path=None, show=True):
    """Scatter-plot 2-D t-SNE embeddings coloured by cluster, then save the figure.

    Each distinct cluster id is drawn as its own scatter series, legend-labelled
    ``Cluster <id + 1>``. When ``largest_cluster_label`` names a cluster, that
    cluster's legend entry gets the suffix ", Real Samples" and every other
    cluster gets ", Fake Samples". When it is ``None`` (the default) or ``-1``,
    only the plain cluster numbers are shown. Every point is annotated with
    the matching entry of ``file_names``.

    Args:
        embeddings_2d: Array indexable as ``embeddings_2d[idx, 0]`` and
            ``embeddings_2d[idx, 1]`` (presumably a NumPy array of shape
            (n_points, 2) -- confirm with callers).
        cluster_labels: Numeric cluster id per point, aligned with the rows of
            ``embeddings_2d``.
        file_names: Annotation text per point; ``file_names[i]`` labels row i.
        largest_cluster_label: Cluster id to mark as "Real Samples", or
            ``None`` / ``-1`` to disable the real/fake legend split.
        save_path: Output image path. Defaults to ``static/images/tsne.png``
            joined with the platform's path separator. Missing parent
            directories are created.
        show: If true (the default), call ``plt.show()`` after saving.

    The figure is always closed before returning.
    """
    import os

    import numpy as np

    if save_path is None:
        # Built portably: the former literal r'static\images\tsne.png' only
        # names a nested path on Windows; on POSIX it is one odd file name.
        save_path = os.path.join('static', 'images', 'tsne.png')

    labels = np.asarray(cluster_labels)
    # Both None (the default) and -1 mean "no designated real cluster";
    # previously None fell through and marked every cluster as fake.
    split_real_fake = (largest_cluster_label is not None
                       and largest_cluster_label != -1)

    fig = plt.figure(figsize=(10, 8))

    # np.unique returns the ids sorted, so the legend order is stable.
    for label in np.unique(labels):
        indices = np.flatnonzero(labels == label)
        legend_text = f'Cluster {label + 1}'
        if split_real_fake:
            kind = 'Real' if label == largest_cluster_label else 'Fake'
            legend_text = f'{legend_text}, {kind} Samples'
        plt.scatter(
            embeddings_2d[indices, 0],
            embeddings_2d[indices, 1],
            label=legend_text,
        )

    for i, file_name in enumerate(file_names):
        plt.annotate(file_name, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    plt.legend()
    plt.grid()
    plt.title("t-SNE Visualization of Features")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # savefig does not create directories; it raises if they are missing.
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)

    if show:
        plt.show()
    # Release the figure. Under a non-interactive backend plt.show() returns
    # without closing it, so repeated calls would otherwise accumulate figures.
    plt.close(fig)
Leave a Comment