Untitled
unknown
plain_text
a year ago
934 B
3
Indexable
def train_model(model, B, N, n_training_steps):
    """Train ``model`` to predict ``compute_nn`` targets on freshly sampled data.

    Each step samples a new batch of datasets and queries, computes the
    ground-truth targets with ``compute_nn``, and updates the model on the
    MSE between its predictions and those targets.

    Args:
        model: Object callable as ``model(datasets, queries)`` that also
            exposes an ``update(loss)`` method.
        B: Batch size passed to ``sample_datasets`` and ``sample_queries``.
        N: Second argument to ``sample_datasets`` (per-dataset size,
            per the ``B x N`` shape note below).
        n_training_steps: Number of sample/predict/update iterations to run.
    """
    for _ in range(n_training_steps):
        batch_datasets = sample_datasets(B, N)  # B x N dimensional tensor
        batch_queries = sample_queries(B, batch_datasets)  # B dimensional vector
        batch_nns = compute_nn(batch_datasets, batch_queries)  # B dimensional vector
        predicted_nns = model(batch_datasets, batch_queries)
        loss = compute_mse(predicted_nns, batch_nns)
        model.update(loss)


def eval_model(model, n_samples, N):
    """Evaluate ``model`` on a freshly sampled evaluation set.

    Args:
        model: Object callable as ``model(datasets, queries)``.
        n_samples: Number of evaluation samples; used where ``train_model``
            uses its batch size ``B``.
        N: Second argument to ``sample_datasets`` (per-dataset size).

    Returns:
        tuple: ``(loss, accuracy)``, where ``loss`` comes from ``compute_mse``
        and ``accuracy`` from ``compute_acc``, both comparing predictions
        against ``compute_nn`` targets.
    """
    eval_datasets = sample_datasets(n_samples, N)  # n_samples x N dimensional tensor
    # Use n_samples here: ``B`` is not a parameter of this function, so the
    # original call raised NameError (or silently picked up an unrelated global).
    eval_queries = sample_queries(n_samples, eval_datasets)  # n_samples dimensional vector
    eval_nns = compute_nn(eval_datasets, eval_queries)  # n_samples dimensional vector
    predicted_nns = model(eval_datasets, eval_queries)
    loss = compute_mse(predicted_nns, eval_nns)
    accuracy = compute_acc(predicted_nns, eval_nns)
    # Previously both metrics were computed and then discarded.
    return loss, accuracy
Editor is loading...
Leave a Comment