Untitled

mail@pastecode.io avatar
unknown
plain_text
a month ago
934 B
0
Indexable
Never
def train_model(model, B, N, n_training_steps):
    """Run ``n_training_steps`` update steps on freshly sampled batches.

    Every step draws a new batch of datasets and queries and builds targets
    with ``compute_nn``. It then scores the model's prediction against those
    targets with ``compute_mse`` and passes the result to ``model.update``.

    Args:
        model: Callable as ``model(datasets, queries)`` that also exposes
            ``update(loss)``.
        B: Batch size forwarded to ``sample_datasets`` and ``sample_queries``.
        N: Second argument forwarded to ``sample_datasets`` (presumably the
            number of points per dataset — confirm).
        n_training_steps: Number of update steps to perform.
    """
    for step in range(n_training_steps):
        # presumably a B x N tensor of datasets — TODO confirm
        datasets = sample_datasets(B, N)
        # presumably one query per dataset (length-B vector) — TODO confirm
        queries = sample_queries(B, datasets)
        targets = compute_nn(datasets, queries)

        step_loss = compute_mse(model(datasets, queries), targets)
        model.update(step_loss)

def eval_model(model, n_samples, N):
    """Evaluate ``model`` on ``n_samples`` freshly sampled datasets.

    Samples evaluation datasets and queries and builds targets with
    ``compute_nn``. It then scores the model's predictions with both
    ``compute_mse`` and ``compute_acc``.

    Args:
        model: Callable as ``model(datasets, queries)`` returning predictions.
        n_samples: Number of evaluation datasets (and queries) to draw.
        N: Second argument forwarded to ``sample_datasets`` (presumably the
            number of points per dataset — confirm).

    Returns:
        A ``(loss, accuracy)`` tuple: the outputs of ``compute_mse`` and
        ``compute_acc`` on the predictions versus the targets.
    """
    # presumably an n_samples x N tensor — TODO confirm
    eval_datasets = sample_datasets(n_samples, N)
    # The batch size for evaluation is n_samples; ``B`` is not in scope here.
    eval_queries = sample_queries(n_samples, eval_datasets)
    eval_nns = compute_nn(eval_datasets, eval_queries)

    predicted_nns = model(eval_datasets, eval_queries)

    loss = compute_mse(predicted_nns, eval_nns)
    accuracy = compute_acc(predicted_nns, eval_nns)
    # Hand the metrics back to the caller instead of discarding them.
    return loss, accuracy
Leave a Comment