from transformers import RobertaForQuestionAnswering
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch

# Load the model
model = RobertaForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

# Set up the optimizer (AdamW is imported from torch.optim; the version
# that used to ship with transformers is deprecated)
optimizer = AdamW(model.parameters(), lr=3e-5)

# Early stopping parameters
patience = 2  # Number of epochs to wait for improvement
best_val_loss = float('inf')
epochs_without_improvement = 0
epochs = 3

def evaluate(model, validation_dataloader):
    """Return the average loss over the validation set."""
    model.eval()
    total_loss = 0
    for batch in validation_dataloader:
        input_ids, attention_mask, start_positions, end_positions = batch
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                            start_positions=start_positions, end_positions=end_positions)
        total_loss += outputs.loss.item()
    return total_loss / len(validation_dataloader)
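
# --- Data preparation (hypothetical sketch) ---
# The training loop below assumes train_dataloader and validation_dataloader
# are already defined. This is one minimal way they might be built; the toy
# question/answer pair, batch size, and max_length are illustrative
# assumptions, not part of the original snippet. Real fine-tuning would use a
# full dataset such as SQuAD 2.0.
from transformers import RobertaTokenizerFast
from torch.utils.data import TensorDataset

tokenizer = RobertaTokenizerFast.from_pretrained("deepset/roberta-base-squad2")

# Toy question/context/answer triple standing in for real training data
question = "What is the capital of France?"
context = "Paris is the capital of France."
answer = "Paris"
answer_start_char = context.find(answer)
answer_end_char = answer_start_char + len(answer) - 1

encoding = tokenizer(question, context, truncation=True,
                     padding="max_length", max_length=384, return_tensors="pt")

# Map the answer's character offsets to token positions; sequence_index=1
# selects the context in a (question, context) pair
start_tok = encoding.char_to_token(0, answer_start_char, sequence_index=1)
end_tok = encoding.char_to_token(0, answer_end_char, sequence_index=1)

dataset = TensorDataset(encoding["input_ids"], encoding["attention_mask"],
                        torch.tensor([start_tok]), torch.tensor([end_tok]))
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
validation_dataloader = DataLoader(dataset, batch_size=8)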
# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        optimizer.zero_grad()
        # Unpack the inputs from the DataLoader
        input_ids, attention_mask, start_positions, end_positions = batch
        # Forward pass (passing start/end positions makes the model return a loss)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                        start_positions=start_positions, end_positions=end_positions)
        loss = outputs.loss
        # Backward pass
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_dataloader)
    # Evaluate on the validation set
    val_loss = evaluate(model, validation_dataloader)
    print(f"Epoch: {epoch + 1}, Training Loss: {avg_train_loss}, Validation Loss: {val_loss}")
    # Check for early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        print(f"Validation loss improved to {val_loss}. Saving model...")
        # Save the model checkpoint
        model.save_pretrained("best_model")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered. Stopping training.")
            break
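
# --- Using the saved checkpoint (hypothetical sketch) ---
# A minimal example of loading the best checkpoint and extracting an answer;
# the question/context pair is an illustrative assumption, and tokenizer is
# the one created in the data-preparation sketch above.
best_model = RobertaForQuestionAnswering.from_pretrained("best_model")
best_model.eval()

inputs = tokenizer("What is the capital of France?",
                   "Paris is the capital of France.", return_tensors="pt")
with torch.no_grad():
    qa_outputs = best_model(**inputs)

# Take the most likely start/end token positions and decode the answer span
start_idx = int(qa_outputs.start_logits.argmax())
end_idx = int(qa_outputs.end_logits.argmax())
print(tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx + 1]))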