Untitled
unknown
plain_text
10 months ago
882 B
11
Indexable
def train(dataloader, model, loss_fn, optimizer):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    For each batch the inputs and targets are moved to ``DEVICE`` (a name
    this function reads from module scope), the loss is computed with
    ``loss_fn``, gradients are backpropagated, and ``optimizer`` takes a step
    before its gradients are cleared.

    Progress is printed on every 100th batch (starting with the first), and a
    summary line with accuracy and average loss is printed once the pass is
    finished. Nothing is returned.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches that also
            exposes ``len()`` and a ``.dataset`` with ``len()``.
        model: Callable module; switched to training mode via ``.train()``.
        loss_fn: Callable taking ``(predictions, targets)`` and returning a
            value that supports ``.item()`` and ``.backward()``.
        optimizer: Object providing ``step()`` and ``zero_grad()``.
    """
    n_samples = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.train()

    running_loss = 0
    n_correct = 0

    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        # Forward pass and loss for this batch.
        outputs = model(inputs)
        batch_loss = loss_fn(outputs, targets)
        running_loss += batch_loss.item()

        # A prediction counts as correct when its rounded value equals the
        # target element-wise.
        # NOTE(review): presumably outputs are probabilities for binary
        # targets with the same shape as ``targets`` -- confirm with caller.
        matches = (outputs.round() == targets).type(torch.float)
        n_correct += matches.sum().item()

        # Backward pass, parameter update, then reset gradients.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Only report on every 100th batch.
        if step % 100 != 0:
            continue

        loss_value = batch_loss.item()
        # Estimate of samples processed so far, based on this batch's size.
        seen = (step + 1) * len(inputs)
        print(f"loss: {loss_value:>7f} [{seen:>5d}/{n_samples:>5d}]")

    # Average loss is per batch; accuracy is per sample.
    mean_loss = running_loss / n_batches
    accuracy = n_correct / n_samples
    print(f"Train Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {mean_loss:>8f} \n")
Leave a Comment