Untitled
unknown
plain_text
2 months ago
590 B
6
Indexable
def test(dataloader, model, loss_fn): size = len(dataloader.dataset) num_batches = len(dataloader) model.eval() test_loss, correct = 0, 0 with torch.no_grad(): for X, y in dataloader: X, y = X.to(DEVICE), y.to(DEVICE) pred = model(X) test_loss += loss_fn(pred, y).item() correct += (pred.round() == y).type(torch.float).sum().item() test_loss /= num_batches correct /= size print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") return correct, test_loss
Editor is loading...
Leave a Comment