Untitled
def _batch_stats(outputs, labels, loss):
    """Return (summed loss, number correct, batch size) for a single batch."""
    batch_size = labels.size(0)
    # Scale the per-batch loss by batch size so the epoch average is per-sample.
    # NOTE(review): this assumes the criterion uses mean reduction — confirm
    # for custom criteria.
    summed_loss = loss.item() * batch_size
    # argmax over dim 1 picks the predicted class index (first max on ties,
    # same as torch.max(outputs, 1)).
    predicted = outputs.argmax(dim=1)
    correct = (predicted == labels).sum().item()
    return summed_loss, correct, batch_size


def _epoch_metrics(loss_sum, correct, total):
    """Return (mean loss, accuracy) for an epoch, or (nan, nan) if it saw no samples."""
    if total == 0:
        # An empty loader (e.g. no validation data) has no meaningful metrics;
        # report nan rather than crashing with ZeroDivisionError.
        return float("nan"), float("nan")
    return loss_sum / total, correct / total


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader``.

    Puts ``model`` in training mode and performs one optimizer step per batch.

    Args:
        model: The ``torch.nn.Module`` being trained.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen this epoch, where
        accuracy compares ``outputs.argmax(dim=1)`` against ``labels``.
        Both values are ``nan`` if the loader yields no samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss, batch_correct, batch_size = _batch_stats(outputs, labels, loss)
        loss_sum += batch_loss
        correct += batch_correct
        total += batch_size
    return _epoch_metrics(loss_sum, correct, total)


def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its parameters.

    Puts ``model`` in eval mode and disables gradient tracking.

    Args:
        model: The ``torch.nn.Module`` being evaluated.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples in the loader. Both values
        are ``nan`` if the loader yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_loss, batch_correct, batch_size = _batch_stats(outputs, labels, loss)
            loss_sum += batch_loss
            correct += batch_correct
            total += batch_size
    return _epoch_metrics(loss_sum, correct, total)


def _save_loss_plot(train_losses, val_losses, path):
    """Plot per-epoch train/validation loss curves and save the figure to ``path``."""
    epochs = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots()
    try:
        ax.plot(epochs, train_losses, label='Train Loss', marker='o')
        ax.plot(epochs, val_losses, label='Val Loss', marker='o')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('Training and Validation Loss Over Epochs')
        ax.legend()
        ax.grid(True)
        fig.savefig(path)
    finally:
        # Always release the figure so repeated calls don't accumulate
        # open figures (and their memory) inside pyplot.
        plt.close(fig)


def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs=5, device='cpu', save_loss_plot=True,
                plot_path='training_val_loss_plot.png'):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Prints per-epoch train/validation loss and accuracy. The model is left in
    eval mode on return, because validation runs last in each epoch.

    Args:
        model: The ``torch.nn.Module`` to train; moved to ``device`` first.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches. If
            it is empty, the validation metrics are reported as ``nan``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of train+validate cycles to run.
        device: Device the model and batches are moved to.
        save_loss_plot: If true, save a loss-vs-epoch plot to ``plot_path``.
        plot_path: Destination file for the loss plot.

    Returns:
        The trained ``model``.
    """
    model.to(device)
    train_loss_history = []
    val_loss_history = []
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_epoch(model, val_loader, criterion, device)
        train_loss_history.append(train_loss)
        val_loss_history.append(val_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    if save_loss_plot:
        _save_loss_plot(train_loss_history, val_loss_history, plot_path)
        print(f"Training and validation loss plot saved as '{plot_path}'")
    return model
Leave a Comment