Untitled
unknown
plain_text
4 years ago
2.2 kB
10
Indexable
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass of ``model`` over ``loader``.

    If ``optimizer`` is given, this is a training pass. The model is put in
    training mode, and every batch does zero_grad -> forward -> backward ->
    optimizer step. If ``optimizer`` is None, this is an evaluation pass. The
    model is put in eval mode and autograd is disabled.

    Args:
        model: the network being trained/evaluated.
        loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss module called as ``criterion(outputs, labels)``.
            NOTE(review): assumed to be a torch.nn loss whose ``reduction`` is
            'mean' (the torch default) or 'sum'. Objects without a
            ``reduction`` attribute are treated as 'mean'.
        device: device each batch is moved to.
        optimizer: optimizer to step after each batch, or None to evaluate.

    Returns:
        ``(avg_loss, accuracy_pct)``. ``avg_loss`` is the loss averaged over
        individual examples, and ``accuracy_pct`` is the percentage of
        examples whose arg-max prediction equals the label. Both are NaN when
        the loader yields no examples.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()
    sum_reduction = getattr(criterion, 'reduction', 'mean') == 'sum'

    loss_sum = 0.0  # sum of per-example losses over the whole pass
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                # fresh gradients for every batch
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Convert the batch loss to a per-example sum. Weighting by the
            # real batch size keeps the epoch average exact even when the last
            # batch is smaller than the configured batch size.
            batch_loss = loss.item()
            loss_sum += batch_loss if sum_reduction else batch_loss * batch_size

            preds = outputs.argmax(dim=1)
            n_correct += (preds == labels).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        # An empty loader has no meaningful statistics; avoid ZeroDivisionError.
        return float('nan'), float('nan')
    return loss_sum / n_seen, (100.0 * n_correct) / n_seen


print('Beginning Model Training....')
for epoch in range(NUM_EPOCHS):
    #### TRAIN ####
    avg_train_loss, train_accuracy = _run_epoch(
        model, dataloader_train, criterion, device, optimizer
    )
    #### VALIDATION ####
    avg_val_loss, val_accuracy = _run_epoch(model, dataloader_val, criterion, device)
    #### EPOCH STATISTICS (losses are per example) ####
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    print('Epoch {}/{} | avg train loss: {:.4f} | train accuracy: {:.3f} | avg val loss: {:.4f} | val accuracy: {:.3f}'.format(epoch+1, NUM_EPOCHS, avg_train_loss, train_accuracy, avg_val_loss, val_accuracy))
Editor is loading...