import time

import numpy as np
import torch

train_loss_data, valid_loss_data = [], []
since = time.time()
best_loss = np.inf       # Lowest validation loss seen so far
patience = 50            # Early stopping patience
early_stop_counter = 0   # Counter for early stopping

for epoch in range(epochs):
    print("Epoch: {}/{}".format(epoch + 1, epochs))

    # Reset training and validation statistics for the epoch
    train_loss = 0.0
    valid_loss = 0.0
    total = 0
    correct = 0
    e_since = time.time()

    # Train the model for one epoch
    train_loss += train(model, train_dataloader, optimizer, criterion, tokenizer_config)

    # Evaluate on the held-out set
    out = evaluate(model, test_dataloader, criterion, tokenizer_config)
    total += out[0]
    correct += out[1]
    valid_loss += out[2]

    scheduler.step()

    # Calculate average loss over the epoch
    train_loss = train_loss / len(train_dataloader.dataset)
    valid_loss = valid_loss / len(test_dataloader.dataset)

    # Append losses for visualization
    train_loss_data.append(train_loss * 100)
    valid_loss_data.append(valid_loss * 100)

    # Check for validation loss improvement
    if valid_loss < best_loss:
        best_loss = valid_loss
        torch.save(model.state_dict(), "news_model1.pth")  # Save best model
        early_stop_counter = 0  # Reset counter if validation loss improves
    else:
        early_stop_counter += 1  # Increment counter if no improvement

    # Print training/validation statistics
    print("\tTrain loss: {:.6f}..".format(train_loss),
          "\tValid Loss: {:.6f}..".format(valid_loss),
          "\tAccuracy: {:.4f}".format(correct / total * 100))

    # Check for early stopping
    if early_stop_counter >= patience:
        print("Early stopping triggered after {} epochs with no improvement.".format(patience))
        break

time_elapsed = time.time() - since
print('Training completed in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))
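
The loop above relies on externally defined objects (model, optimizer, criterion, scheduler, dataloaders, epochs, tokenizer_config) and on two helper functions, train() and evaluate(). Those helpers are not shown in this paste; the sketch below is a minimal, assumed version that matches the return values the loop expects (train returns a summed per-sample loss, evaluate returns (total, correct, summed_loss)). The batch format, the mean-reducing criterion, and the unused tokenizer_config argument are all assumptions, not the original implementation.

import torch

def train(model, dataloader, optimizer, criterion, tokenizer_config):
    """One training epoch; returns the summed per-sample loss (assumed contract)."""
    model.train()
    running_loss = 0.0
    for inputs, labels in dataloader:            # batch format is an assumption
        optimizer.zero_grad()
        outputs = model(inputs)                  # tokenizer_config unused in this sketch
        loss = criterion(outputs, labels)        # assumes mean reduction over the batch
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * labels.size(0)
    return running_loss

def evaluate(model, dataloader, criterion, tokenizer_config):
    """Returns (total, correct, summed_loss) over the evaluation set (assumed contract)."""
    model.eval()
    total, correct, running_loss = 0, 0, 0.0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
    return total, correct, running_loss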