Untitled
unknown
plain_text
2 months ago
882 B
8
Indexable
def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch over ``dataloader`` and report loss/accuracy.

    Each batch is moved to ``device``, passed through ``model``, scored with
    ``loss_fn`` and back-propagated; ``optimizer`` then steps and clears the
    gradients. Progress is printed every 100 batches and a summary at the end.

    Accuracy is measured with ``pred.round() == y``. This presumably assumes
    ``pred`` holds probabilities in [0, 1] with the same shape as ``y``
    (e.g. a sigmoid output for binary classification, one value per
    sample) -- TODO confirm against the model definition.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that also exposes
            ``.dataset`` (e.g. a ``torch.utils.data.DataLoader``).
        model: Module to train; switched to training mode here.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to. When omitted, the
            module-level ``DEVICE`` is used (the previous behaviour).

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch loss values and
        the fraction of visited samples predicted correctly. Both are
        ``nan`` when the dataloader yields no batches.
    """
    if device is None:
        device = DEVICE
    # Only used for the progress display below.
    size = len(dataloader.dataset)
    model.train()
    train_loss, correct = 0.0, 0.0
    num_batches, seen = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)
        train_loss += loss.item()
        correct += (pred.round() == y).type(torch.float).sum().item()
        # Count the samples actually visited: with drop_last=True or a
        # subset sampler this is smaller than len(dataloader.dataset), and
        # dividing by the dataset size would understate accuracy.
        seen += len(X)
        num_batches += 1

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss_value, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")

    # An empty loader previously raised ZeroDivisionError; report the
    # metrics as undefined (nan) instead.
    avg_loss = train_loss / num_batches if num_batches else float("nan")
    accuracy = correct / seen if seen else float("nan")
    print(f"Train Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")
    return avg_loss, accuracy
Editor is loading...
Leave a Comment