from transformers import RobertaForQuestionAnswering
from torch.optim import AdamW  # the original snippet used AdamW without importing it
from torch.utils.data import DataLoader
import torch

# Load the model
model = RobertaForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

# Set up optimizer
optimizer = AdamW(model.parameters(), lr=3e-5)

# Early stopping parameters
patience = 2  # number of epochs to wait for improvement
best_val_loss = float('inf')
epochs_without_improvement = 0
epochs = 3

# NOTE: train_dataloader and validation_dataloader must be defined before the
# training loop runs; each batch is an (input_ids, attention_mask,
# start_positions, end_positions) tuple. See the sketch below for one way to
# build them.

def evaluate(model, validation_dataloader):
    """Return the average loss over the validation set."""
    model.eval()
    total_loss = 0
    for batch in validation_dataloader:
        input_ids, attention_mask, start_positions, end_positions = batch
        with torch.no_grad():
            outputs = model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            start_positions=start_positions,
                            end_positions=end_positions)
        total_loss += outputs.loss.item()
    return total_loss / len(validation_dataloader)

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        optimizer.zero_grad()

        # Unpack the inputs from the DataLoader
        input_ids, attention_mask, start_positions, end_positions = batch

        # Forward pass (the model computes the QA loss when positions are given)
        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dataloader)

    # Evaluate on the validation set
    val_loss = evaluate(model, validation_dataloader)
    print(f"Epoch: {epoch + 1}, Training Loss: {avg_train_loss}, Validation Loss: {val_loss}")

    # Check for early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        print(f"Validation loss improved to {val_loss}. Saving model...")
        # Save the best checkpoint so far
        model.save_pretrained("best_model")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered. Stopping training.")
            break
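
The loop above assumes train_dataloader and validation_dataloader already exist. Below is a minimal sketch of one way to build them, assuming SQuAD-style (question, context, answer-span) examples tokenized with the matching deepset/roberta-base-squad2 fast tokenizer; the toy example and the variable names (questions, contexts, answers) are illustrative, not part of the original paste.

from transformers import RobertaTokenizerFast
from torch.utils.data import DataLoader, TensorDataset
import torch

tokenizer = RobertaTokenizerFast.from_pretrained("deepset/roberta-base-squad2")

# Toy SQuAD-style data; a real run would use thousands of examples.
questions = ["Who wrote Hamlet?"]
contexts = ["Hamlet is a tragedy written by William Shakespeare around 1600."]
answers = [{"text": "William Shakespeare", "answer_start": 31}]

encodings = tokenizer(questions, contexts, truncation=True,
                      padding="max_length", max_length=384,
                      return_offsets_mapping=True, return_tensors="pt")

# Map character-level answer spans to token-level start/end positions.
start_positions, end_positions = [], []
for i, ans in enumerate(answers):
    start_char = ans["answer_start"]
    end_char = start_char + len(ans["text"])
    offsets = encodings["offset_mapping"][i].tolist()
    seq_ids = encodings.sequence_ids(i)  # 1 marks context tokens
    start_tok = end_tok = 0
    for idx, ((tok_start, tok_end), seq_id) in enumerate(zip(offsets, seq_ids)):
        if seq_id != 1:
            continue  # skip special and question tokens
        if tok_start <= start_char < tok_end:
            start_tok = idx
        if tok_start < end_char <= tok_end:
            end_tok = idx
    start_positions.append(start_tok)
    end_positions.append(end_tok)

dataset = TensorDataset(encodings["input_ids"],
                        encodings["attention_mask"],
                        torch.tensor(start_positions),
                        torch.tensor(end_positions))
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
validation_dataloader = DataLoader(dataset, batch_size=8)  # toy data reused for illustration

A TensorDataset is used so each batch unpacks as the plain tuple the training loop expects. After training, the best checkpoint saved by the loop can be reloaded with RobertaForQuestionAnswering.from_pretrained("best_model").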