Untitled

mail@pastecode.io avatar
unknown
plain_text
22 days ago
2.0 kB
2
Indexable
Never
# Train/evaluate for up to `epochs` epochs, checkpointing the model with the
# lowest validation loss to "news_model1.pth" and stopping early after
# `patience` consecutive epochs without improvement.
#
# Relies on names defined elsewhere in the file (not visible here): epochs,
# model, train_dataloader, test_dataloader, optimizer, criterion, scheduler,
# tokenizer_config, train(), evaluate(), and the np / time / torch modules.
# NOTE(review): losses are divided by the dataset size below, which presumes
# train() / evaluate() return losses summed over samples rather than
# per-batch means -- confirm against their definitions.
train_loss_data, valid_loss_data = [], []
# np.Inf (capitalized) was removed in NumPy 2.0; np.inf works on all versions.
valid_loss_min = np.inf  # kept for compatibility; not read by this loop
since = time.time()
best_loss = np.inf
patience = 50  # Early stopping patience (epochs without improvement)
early_stop_counter = 0  # Consecutive epochs without validation improvement

for epoch in range(epochs):
    print("Epoch: {}/{}".format(epoch + 1, epochs))
    e_since = time.time()  # per-epoch timer, reported in the stats line

    # Train for one epoch.
    train_loss = train(model, train_dataloader, optimizer, criterion, tokenizer_config)

    # Evaluate; this loop reads evaluate()'s result as
    # (sample count, correct count, validation loss, ...).
    out = evaluate(model, test_dataloader, criterion, tokenizer_config)
    total, correct, valid_loss = out[0], out[1], out[2]

    scheduler.step()

    # Calculate average loss over an epoch
    train_loss = train_loss / len(train_dataloader.dataset)
    valid_loss = valid_loss / len(test_dataloader.dataset)

    # Append losses for visualization (scaled by 100, as in the original plots)
    train_loss_data.append(train_loss * 100)
    valid_loss_data.append(valid_loss * 100)

    # Check for validation loss improvement
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), "news_model1.pth")  # Save best model
        early_stop_counter = 0  # Reset counter if validation loss improves
    else:
        early_stop_counter += 1  # Increment counter if no improvement

    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    accuracy = correct / total * 100 if total else float("nan")

    # Print training/validation statistics
    print("\tTrain loss:{:.6f}..".format(train_loss),
          "\tValid Loss:{:.6f}..".format(valid_loss),
          "\tAccuracy: {:.4f}".format(accuracy),
          "\tTime: {:.1f}s".format(time.time() - e_since))

    # Check for early stopping
    if early_stop_counter >= patience:
        print("Early stopping triggered after {} epochs with no improvement.".format(patience))
        break

time_elapsed = time.time() - since
# Round once, then split; formatting `time_elapsed % 60` with {:.0f} could
# print e.g. "1m 60s" for 119.6 seconds.
minutes, seconds = divmod(round(time_elapsed), 60)
print('Training completed in {:.0f}m {:.0f}s'.format(minutes, seconds))
Leave a Comment