Untitled

 avatar
unknown
plain_text
6 months ago
983 B
6
Indexable
def normalize_embeddings(embedding_corpus):
    """Scale every row of an embedding matrix to unit Euclidean (L2) length.

    Args:
        embedding_corpus: Array-like of embeddings with at least two
            dimensions; the norm is taken along axis 1 (for a ``B x p``
            matrix, one norm per row).

    Returns:
        A JAX array of the same shape in which each vector along axis 1 has
        been divided by its L2 norm. All-zero vectors are returned as zeros
        instead of NaN.
    """
    embeddings = jnp.asarray(embedding_corpus)
    # Stay in jax.numpy (rather than NumPy) so the function also works on
    # traced values, i.e. inside jax.jit / jax.vmap.
    norms = jnp.linalg.norm(embeddings, axis=1, keepdims=True)
    # A zero-length vector has no direction: divide by 1 to avoid 0/0 = NaN.
    safe_norms = jnp.where(norms == 0, 1.0, norms)
    return embeddings / safe_norms

def calculate_embedding_similarity_with_corpus(embedding_corpus, embedding_1):
    """Reduce the dot products of one embedding with a corpus to a scalar.

    Args:
        embedding_corpus: Expected to be a JAX array of shape ``B x p``
            (one embedding per row).
        embedding_1: Expected to be an array of shape ``p``.

    Returns:
        The L2 norm of the vector holding the dot product of
        ``embedding_1`` with each corpus row.
    """
    def _dot_with_query(corpus_row):
        return corpus_row.dot(embedding_1)

    # Map over the leading (row) axis of the corpus: one dot product per row.
    per_row_dots = jax.vmap(_dot_with_query, in_axes=0)(embedding_corpus)
    return jnp.linalg.norm(per_row_dots, ord=2)

def add_corpus_similarity_column_and_sort(dataset_with_embeddings, original_embedding_corpus: jnp.array):
    """Score each message embedding against a corpus, store the score, sort.

    Args:
        dataset_with_embeddings: Object that supports
            ``['message_embedding']`` lookup plus ``add_column`` and ``sort``
            (looks like a Hugging Face ``datasets.Dataset`` - TODO confirm).
        original_embedding_corpus: Corpus embeddings; they are normalized
            here before use.

    Returns:
        The result of adding a ``'corpus_similarity'`` column to the dataset
        and sorting by it (sort direction is whatever ``sort`` defaults to -
        presumably ascending; verify against the dataset library).
    """
    normalized_corpus = normalize_embeddings(original_embedding_corpus)
    normalized_messages = normalize_embeddings(dataset_with_embeddings['message_embedding'])

    def _score_against_corpus(message_embedding):
        return calculate_embedding_similarity_with_corpus(normalized_corpus, message_embedding)

    # One scalar score per message: map the scorer over the message axis.
    corpus_scores = jax.vmap(_score_against_corpus)(normalized_messages)

    scored_dataset = dataset_with_embeddings.add_column('corpus_similarity', np.array(corpus_scores))
    return scored_dataset.sort('corpus_similarity')
Editor is loading...
Leave a Comment