Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
2.2 kB
1
Indexable
Never
def _run_epoch(model, dataloader, criterion, device, optimizer=None):
    """Run one pass over ``dataloader`` and return ``(avg_loss, accuracy)``.

    When ``optimizer`` is given the pass trains the model (train mode,
    gradients enabled, one optimizer step per batch); otherwise it only
    evaluates (eval mode, gradients disabled).

    Args:
        model: the network to run; switched to train or eval mode here.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        device: target passed to ``Tensor.to`` for inputs and labels.
        optimizer: if not None, ``zero_grad``/``step`` are called per batch.

    Returns:
        ``avg_loss``: loss averaged over examples (not over batches), so a
        smaller final batch is weighted correctly.
        ``accuracy``: percentage of examples whose arg-max prediction
        equals the label.
        Both are ``float('nan')`` if the dataloader yields no examples.
    """
    training = optimizer is not None
    # train(False) is what Module.eval() does.
    model.train(training)

    # PyTorch losses default to reduction='mean' (the batch loss is already
    # a per-example average); with reduction='sum' it is the batch total.
    # Convert either form into a running per-example sum.
    # NOTE(review): assumes a scalar loss (reduction='none' would also fail
    # at loss.backward()); a class-weighted CrossEntropyLoss normalizes by
    # the summed weights, so the per-example figure is approximate there.
    loss_is_summed = getattr(criterion, "reduction", "mean") == "sum"

    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                # zero the parameter gradients (new gradients per batch)
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            batch_loss = loss.item()
            loss_sum += batch_loss if loss_is_summed else batch_loss * batch_size

            # argmax returns an integer index tensor without grad history,
            # so the discouraged ``.data`` access is unnecessary.
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            seen += batch_size

    if seen == 0:
        # Empty loader: nothing to average (the original divided by zero).
        return float("nan"), float("nan")
    return loss_sum / seen, 100.0 * correct / seen


print('Beginning Model Training....')

#### TRAIN / VALIDATION ####
for epoch in range(NUM_EPOCHS):
    avg_train_loss, train_accuracy = _run_epoch(
        model, dataloader_train, criterion, device, optimizer=optimizer
    )
    avg_val_loss, val_accuracy = _run_epoch(
        model, dataloader_val, criterion, device
    )

    #### EPOCH STATISTICS ####
    # Both losses are per-example averages over the whole split.
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)

    print('Epoch {}/{} | avg train loss: {:.4f} | train accuracy: {:.3f} | avg val loss: {:.4f} | val accuracy: {:.3f}'.format(
        epoch + 1, NUM_EPOCHS, avg_train_loss, train_accuracy, avg_val_loss, val_accuracy))