Untitled

 avatar
unknown
plain_text
a month ago
2.8 kB
2
Indexable
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader`` and return its metrics.

    The model is switched to training mode. For every batch the gradients
    are reset, the loss is back-propagated and the optimizer takes one step.

    Args:
        model: Callable model; ``model(inputs)`` must return scores whose
            dimension 1 indexes the classes (the prediction is the argmax
            along that dimension).
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        device: Target passed to ``.to()`` for every batch.

    Returns:
        A ``(loss, accuracy)`` tuple. ``loss`` is the sum of each batch loss
        multiplied by its batch size, divided by the number of samples seen
        (i.e. the per-sample mean if ``criterion`` returns a batch mean).
        ``accuracy`` is the fraction of samples predicted correctly.

    Raises:
        ValueError: If ``train_loader`` yields no samples, since the
            averages would otherwise require dividing by zero.
    """
    model.train()
    train_epoch_loss = 0.0
    train_correct = 0
    train_total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Zero the parameter gradients
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        # Update metrics; weight the batch loss by its size so that a
        # smaller final batch does not skew the epoch average.
        train_epoch_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    if train_total == 0:
        # Fail with an explicit message rather than a bare ZeroDivisionError.
        raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")

    train_epoch_loss /= train_total
    train_accuracy = train_correct / train_total

    return train_epoch_loss, train_accuracy

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating its parameters.

    The model is switched to evaluation mode and the whole pass runs under
    ``torch.no_grad()``, so no gradients are recorded.

    Args:
        model: Callable model; ``model(inputs)`` must return scores whose
            dimension 1 indexes the classes (the prediction is the argmax
            along that dimension).
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Target passed to ``.to()`` for every batch.

    Returns:
        A ``(loss, accuracy)`` tuple. ``loss`` is the sum of each batch loss
        multiplied by its batch size, divided by the number of samples seen
        (i.e. the per-sample mean if ``criterion`` returns a batch mean).
        ``accuracy`` is the fraction of samples predicted correctly.

    Raises:
        ValueError: If ``val_loader`` yields no samples, since the averages
            would otherwise require dividing by zero.
    """
    model.eval()
    val_epoch_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch loss by its size so that a smaller final
            # batch does not skew the epoch average.
            val_epoch_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    if val_total == 0:
        # Fail with an explicit message rather than a bare ZeroDivisionError.
        raise ValueError("val_loader yielded no samples; cannot compute epoch metrics")

    val_epoch_loss /= val_total
    val_accuracy = val_correct / val_total

    return val_epoch_loss, val_accuracy

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cpu', save_loss_plot=True):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each epoch calls ``train_epoch`` followed by ``validate_epoch``, records
    both losses and prints a one-line summary of loss and accuracy.

    Args:
        model: Model to train; ``model.to(device)`` is called before the
            first epoch.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable forwarded to the per-epoch helpers.
        optimizer: Optimizer forwarded to ``train_epoch``.
        num_epochs: Number of train/validate rounds to run.
        device: Device forwarded to the per-epoch helpers.
        save_loss_plot: When true, the per-epoch training and validation
            losses are plotted and written to ``training_val_loss_plot.png``
            in the current working directory.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.to(device)
    train_loss_history = []
    val_loss_history = []

    for epoch in range(num_epochs):
        train_epoch_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
        val_epoch_loss, val_accuracy = validate_epoch(model, val_loader, criterion, device)

        train_loss_history.append(train_epoch_loss)
        val_loss_history.append(val_epoch_loss)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_epoch_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
              f"Val Loss: {val_epoch_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    if save_loss_plot:
        epochs = range(1, num_epochs + 1)
        fig = plt.figure()
        try:
            plt.plot(epochs, train_loss_history, label='Train Loss', marker='o')
            plt.plot(epochs, val_loss_history, label='Val Loss', marker='o')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training and Validation Loss Over Epochs')
            plt.legend()
            plt.grid(True)
            plt.savefig('training_val_loss_plot.png')
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated calls do not accumulate open figures.
            plt.close(fig)
        print("Training and validation loss plot saved as 'training_val_loss_plot.png'")

    return model
Leave a Comment