Untitled
unknown
plain_text
6 months ago
983 B
6
Indexable
import jax
import jax.numpy as jnp
import numpy as np


def normalize_embeddings(embedding_corpus):
    """Scale each embedding (last axis) to unit L2 norm.

    Accepts a single vector of shape (p,) or a stack of vectors of shape
    (..., p). Rows whose norm is zero are returned as zero vectors instead
    of NaN, so they cannot poison downstream similarity scores or sorts.
    """
    # keepdims=True keeps the reduced axis so the division broadcasts for
    # both 1-D and N-D input without a manual reshape.
    norms = np.linalg.norm(embedding_corpus, axis=-1, keepdims=True)
    # Avoid 0/0 -> NaN: dividing a zero row by 1 leaves it as zeros.
    norms = np.where(norms == 0, 1, norms)
    return embedding_corpus / norms


def calculate_embedding_similarity_with_corpus(embedding_corpus, embedding_1):
    """Return one aggregate similarity score between a vector and a corpus.

    Args:
        embedding_corpus: array of shape (B, p), one embedding per row.
        embedding_1: array of shape (p,).

    Returns:
        A scalar: the L2 norm of the vector of dot products between
        ``embedding_1`` and every corpus row. With unit-normalized inputs
        each dot product is a cosine similarity, so this is
        sqrt(sum of squared cosine similarities). It is not a per-row score.
    """
    # Dotting embedding_1 against every corpus row is a matrix-vector product.
    dot_products = jnp.dot(embedding_corpus, embedding_1)
    return jnp.linalg.norm(dot_products, ord=2)


def add_corpus_similarity_column_and_sort(dataset_with_embeddings, original_embedding_corpus: jnp.ndarray):
    """Score every message against a corpus, add the score column, and sort.

    Args:
        dataset_with_embeddings: a dataset object exposing a
            ``'message_embedding'`` column, ``add_column`` and ``sort``.
            NOTE(review): presumably a Hugging Face ``datasets.Dataset``;
            confirm against callers.
        original_embedding_corpus: corpus embeddings of shape (B, p).

    Returns:
        The dataset with a new ``'corpus_similarity'`` column, sorted in
        ascending order of that column (least corpus-similar messages first).
    """
    normalized_corpus = normalize_embeddings(original_embedding_corpus)
    normalized_messages = normalize_embeddings(dataset_with_embeddings['message_embedding'])
    # One aggregate similarity score per message, batched over all messages.
    corpus_similarities = jax.vmap(
        lambda message: calculate_embedding_similarity_with_corpus(normalized_corpus, message)
    )(normalized_messages)
    dataset_with_embeddings = dataset_with_embeddings.add_column(
        'corpus_similarity', np.array(corpus_similarities)
    )
    return dataset_with_embeddings.sort('corpus_similarity')
Editor is loading...
Leave a Comment