Untitled

mail@pastecode.io avatar
unknown
plain_text
18 days ago
2.3 kB
2
Indexable
Never
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import RobertaForQuestionAnswering

# Load the pretrained RoBERTa question-answering model by its Hub identifier
# (from_pretrained fetches it on first use and caches it locally).
model = RobertaForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

# Set up optimizer.
# AdamW here is torch.optim.AdamW; the AdamW class once exported by transformers
# is deprecated. torch's AdamW defaults to weight_decay=0.01 — pass weight_decay
# explicitly if a different regularization strength is intended.
optimizer = AdamW(model.parameters(), lr=3e-5)

# Early stopping parameters
patience = 2  # stop after this many consecutive epochs without val-loss improvement
best_val_loss = float('inf')  # lowest validation loss seen so far
epochs_without_improvement = 0  # consecutive epochs since the last improvement
epochs = 3  # maximum number of training epochs

def evaluate(model, validation_dataloader):
    """Return the mean per-batch loss of ``model`` over ``validation_dataloader``.

    Each batch is unpacked as ``(input_ids, attention_mask, start_positions,
    end_positions)`` and passed to ``model`` as keyword arguments; the model's
    output must expose a ``.loss`` tensor. Gradients are disabled for the whole
    pass, and the model's previous train/eval mode is restored afterwards.

    Args:
        model: The module to evaluate.
        validation_dataloader: Any iterable of 4-element batches. It does not
            need to support ``len()``; batches are counted while iterating.

    Returns:
        The average of ``loss.item()`` across all yielded batches (a float).

    Raises:
        ValueError: If ``validation_dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        # One no_grad context for the entire pass instead of one per batch.
        with torch.no_grad():
            for batch in validation_dataloader:
                input_ids, attention_mask, start_positions, end_positions = batch
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    start_positions=start_positions,
                    end_positions=end_positions,
                )
                total_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("validation_dataloader yielded no batches")
    return total_loss / num_batches

# Training loop with early stopping on validation loss.
# NOTE(review): train_dataloader and validation_dataloader are not defined in this
# snippet; presumably they are built elsewhere and yield 4-tuples of
# (input_ids, attention_mask, start_positions, end_positions) — confirm.
for epoch in range(epochs):
    # Ensure training-mode behavior (e.g. dropout) at the start of every epoch.
    model.train()
    total_loss = 0
    num_batches = 0
    for num_batches, batch in enumerate(train_dataloader, start=1):
        optimizer.zero_grad()

        # Unpack the inputs from the DataLoader
        input_ids, attention_mask, start_positions, end_positions = batch

        # Forward pass; supplying the target positions makes the model return a loss.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )
        loss = outputs.loss

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Guard against an empty loader instead of failing with ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("train_dataloader yielded no batches")
    avg_train_loss = total_loss / num_batches

    # Evaluate on validation set
    val_loss = evaluate(model, validation_dataloader)

    print(f"Epoch: {epoch + 1}, Training Loss: {avg_train_loss}, Validation Loss: {val_loss}")

    # Check for early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        print(f"Validation loss improved to {val_loss}. Saving model...")
        # Persist the best weights so far. After the loop ends, the in-memory
        # model holds the LAST epoch's weights; reload "best_model" for the best.
        model.save_pretrained("best_model")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered. Stopping training.")
            break
Leave a Comment