Untitled
unknown
plain_text
3 years ago
2.2 kB
3
Indexable
print('Beginning Model Training....')

# Each epoch runs one optimisation pass over dataloader_train, then a
# gradient-free pass over dataloader_val. The mean per-example losses are
# appended to train_losses / val_losses (defined earlier in the script), and
# a one-line summary is printed.
#
# Losses are accumulated as a *sum over examples* and divided by the number
# of examples seen. This gives a correct per-example average whether the
# criterion averages or sums within a batch, and also when the last batch is
# smaller than DATA_BATCH_SIZE.
# torch.nn loss modules expose their within-batch reduction as `.reduction`.
# Plain callables without that attribute are assumed to return a batch mean.
# NOTE(review): confirm against the criterion constructed earlier in the script.
_loss_reduction = getattr(criterion, 'reduction', 'mean')

for epoch in range(NUM_EPOCHS):
    #### TRAIN ####
    model.train()
    train_loss = 0.0  # sum of per-example losses over the epoch
    train_correct = 0
    train_total = 0

    # cycle through train set
    for inputs, labels in dataloader_train:
        # move the batch to the target device
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients (new gradients per batch)
        optimizer.zero_grad()

        # forward pass, loss, backward pass, weight update
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Scale a batch-mean loss back to a batch sum so every example is
        # weighted equally when dividing by train_total below.
        train_loss += loss.item() * batch_size if _loss_reduction == 'mean' else loss.item()

        # accuracy: predicted class is the arg-max over dim 1
        preds = outputs.argmax(dim=1)
        train_total += batch_size
        train_correct += (preds == labels).sum().item()

    #### VALIDATION ####
    model.eval()
    val_loss = 0.0  # sum of per-example losses over the epoch
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        # cycle through validation set
        for inputs, labels in dataloader_val:
            inputs, labels = inputs.to(device), labels.to(device)

            # forward pass, loss, predictions (no weight update)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            val_loss += loss.item() * batch_size if _loss_reduction == 'mean' else loss.item()

            # validation accuracy
            preds = outputs.argmax(dim=1)
            val_total += batch_size
            val_correct += (preds == labels).sum().item()

    #### EPOCH STATISTICS ####
    # An empty loader reports NaN instead of raising ZeroDivisionError.
    train_accuracy = (100.0 * train_correct) / train_total if train_total else float('nan')
    val_accuracy = (100.0 * val_correct) / val_total if val_total else float('nan')

    # mean loss per example
    avg_train_loss = train_loss / train_total if train_total else float('nan')
    avg_val_loss = val_loss / val_total if val_total else float('nan')

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print('Epoch {}/{} | avg train loss: {:.4f} | train accuracy: {:.3f} | avg val loss: {:.4f} | val accuracy: {:.3f}'.format(
        epoch + 1, NUM_EPOCHS, avg_train_loss, train_accuracy, avg_val_loss, val_accuracy))
Editor is loading...