Untitled
unknown
plain_text
a year ago
934 B
4
Indexable
def train_model(model, B, N, n_training_steps):
    """Train ``model`` to predict nearest neighbours on freshly sampled data.

    Each step samples a new batch of datasets and queries, computes the
    ground-truth nearest neighbours with ``compute_nn``, and updates the
    model using the ``compute_mse`` loss between its predictions and that
    ground truth.

    Args:
        model: Called as ``model(datasets, queries)``; must expose an
            ``update(loss)`` method.
        B: Number of datasets sampled per step (first argument to
            ``sample_datasets``).
        N: Second argument to ``sample_datasets``; per the author's notes
            each dataset has N elements.
        n_training_steps: Number of update steps to run.
    """
    for _ in range(n_training_steps):
        # Shapes below follow the author's notes -- TODO confirm against
        # sample_datasets / sample_queries / compute_nn.
        batch_datasets = sample_datasets(B, N)  # B x N tensor
        batch_queries = sample_queries(B, batch_datasets)  # B-dim vector
        batch_nns = compute_nn(batch_datasets, batch_queries)  # B-dim vector
        predicted_nns = model(batch_datasets, batch_queries)
        loss = compute_mse(predicted_nns, batch_nns)
        model.update(loss)
def eval_model(model, n_samples, N):
    """Evaluate ``model`` on ``n_samples`` freshly sampled datasets.

    Samples evaluation datasets and queries, computes the ground-truth
    nearest neighbours with ``compute_nn``, and scores the model's
    predictions against them.

    Args:
        model: Called as ``model(datasets, queries)``.
        n_samples: Number of evaluation datasets to sample.
        N: Second argument to ``sample_datasets``; per the author's notes
            each dataset has N elements.

    Returns:
        tuple: ``(loss, accuracy)`` as produced by ``compute_mse`` and
        ``compute_acc`` respectively.
    """
    eval_datasets = sample_datasets(n_samples, N)  # n_samples x N tensor
    # The query count must match the number of evaluation datasets, so it is
    # n_samples -- there is no batch size ``B`` in this function's scope.
    eval_queries = sample_queries(n_samples, eval_datasets)  # n_samples-dim vector
    eval_nns = compute_nn(eval_datasets, eval_queries)  # n_samples-dim vector
    predicted_nns = model(eval_datasets, eval_queries)
    loss = compute_mse(predicted_nns, eval_nns)
    accuracy = compute_acc(predicted_nns, eval_nns)
    return loss, accuracy
Leave a Comment