Untitled
unknown
plain_text
22 days ago
2.0 kB
2
Indexable
Never
# Training loop with per-epoch validation, best-checkpoint saving and early
# stopping. Relies on names defined elsewhere in the program: np, time, torch,
# model, epochs, train, evaluate, train_dataloader, test_dataloader,
# optimizer, criterion, scheduler, tokenizer_config.
train_loss_data, valid_loss_data = [], []
# `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
# Kept for any later code that reads this name; the loop below tracks the
# best validation loss in `best_loss`.
valid_loss_min = np.inf
since = time.time()
best_loss = np.inf
patience = 50  # Early stopping patience (epochs without improvement)
early_stop_counter = 0  # Consecutive epochs without validation improvement

for epoch in range(epochs):
    print("Epoch: {}/{}".format(epoch + 1, epochs))
    e_since = time.time()

    # Train for one epoch. The result is divided by the dataset size below,
    # so `train` presumably returns a loss summed over samples -- TODO confirm.
    train_loss = train(model, train_dataloader, optimizer, criterion, tokenizer_config)

    # Evaluate; `evaluate` returns (total, correct, loss, ...) by position.
    out = evaluate(model, test_dataloader, criterion, tokenizer_config)
    total = out[0]
    correct = out[1]
    valid_loss = out[2]

    # NOTE(review): if `scheduler` is ReduceLROnPlateau, step() must be given
    # the validation metric (e.g. scheduler.step(valid_loss)) -- confirm type.
    scheduler.step()

    # Normalize the epoch losses by the number of samples in each dataset.
    train_loss = train_loss / len(train_dataloader.dataset)
    valid_loss = valid_loss / len(test_dataloader.dataset)

    # Append (x100-scaled) losses for later visualization.
    train_loss_data.append(train_loss * 100)
    valid_loss_data.append(valid_loss * 100)

    # Checkpoint on improvement; otherwise count towards early stopping.
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), "news_model1.pth")  # Save best model
        early_stop_counter = 0  # Reset counter if validation loss improves
    else:
        early_stop_counter += 1  # Increment counter if no improvement

    # Guard against an empty validation set, which would divide by zero.
    accuracy = correct / total * 100 if total else 0.0

    # Print training/validation statistics
    print("\tTrain loss:{:.6f}..".format(train_loss),
          "\tValid Loss:{:.6f}..".format(valid_loss),
          "\tAccuracy: {:.4f}".format(accuracy))

    # Check for early stopping
    if early_stop_counter >= patience:
        print("Early stopping triggered after {} epochs with no improvement.".format(patience))
        break

time_elapsed = time.time() - since
print('Training completed in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
Leave a Comment