Untitled

 avatar
unknown
plain_text
a year ago
1.5 kB
5
Indexable
def _cluster_legend_label(label, largest_cluster_label=None):
    """Return the legend text for cluster ``label``.

    Cluster ids are displayed 1-based (``label + 1``). When
    ``largest_cluster_label`` is ``None`` or ``-1`` no cluster is singled out
    and the text is just ``"Cluster N"``; otherwise the matching cluster is
    tagged ``"Real Samples"`` and every other cluster ``"Fake Samples"``.
    """
    text = f'Cluster {label + 1}'
    if largest_cluster_label is None or largest_cluster_label == -1:
        return text
    if label == largest_cluster_label:
        return f'{text}, Real Samples'
    return f'{text}, Fake Samples'


def plot_tsne(embeddings_2d, cluster_labels, file_names, largest_cluster_label=None,
              save_path=None):
    """Scatter-plot 2-D t-SNE embeddings coloured by cluster, save and show it.

    Args:
        embeddings_2d: Array indexable as ``embeddings_2d[rows, col]``; columns
            0 and 1 are plotted. Presumably a NumPy array of shape (n, 2) —
            TODO confirm against callers.
        cluster_labels: Sequence of numeric cluster ids, one per row of
            ``embeddings_2d`` (``label + 1`` is used for display).
        file_names: Annotation strings; ``file_names[i]`` is drawn at point ``i``.
        largest_cluster_label: Cluster to tag as "Real Samples" (all other
            clusters are tagged "Fake Samples"). ``None`` (default) or ``-1``
            disables the Real/Fake tagging.
        save_path: Where to write the image. Defaults to
            ``static/images/tsne.png`` joined with the platform's separator.
            Missing parent directories are created.
    """
    import os

    if save_path is None:
        # os.path.join picks the correct separator; a literal backslash path
        # only resolves to a nested directory on Windows.
        save_path = os.path.join('static', 'images', 'tsne.png')

    fig = plt.figure(figsize=(10, 8))

    # Group point indices by label in a single pass instead of rescanning
    # every label once per cluster.
    indices_by_label = {}
    for i, label in enumerate(cluster_labels):
        indices_by_label.setdefault(label, []).append(i)

    # Sorted so the legend order is ascending and stable across runs.
    for label in sorted(indices_by_label):
        indices = indices_by_label[label]
        plt.scatter(
            embeddings_2d[indices, 0],
            embeddings_2d[indices, 1],
            label=_cluster_legend_label(label, largest_cluster_label),
        )

    for i, file_name in enumerate(file_names):
        plt.annotate(file_name, (embeddings_2d[i, 0], embeddings_2d[i, 1]))

    plt.legend()
    plt.grid()
    plt.title("t-SNE Visualization of Features")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()
    # Release the figure: with non-interactive backends show() does not close
    # it, so repeated calls would otherwise accumulate open figures.
    plt.close(fig)
Editor is loading...
Leave a Comment