Untitled
unknown
plain_text
a year ago
2.8 kB
5
Indexable
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and report metrics.

    Puts ``model`` in training mode, then for every ``(inputs, labels)``
    batch: moves both to ``device``, zeroes the gradients, runs the forward
    pass, computes ``criterion(outputs, labels)``, back-propagates and calls
    ``optimizer.step()``.

    Args:
        model: Network to train; called as ``model(inputs)``.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to
            return the *mean* loss over the batch (e.g. the default
            ``reduction='mean'``) -- TODO confirm for custom losses, because
            the value is re-weighted by batch size below.
        optimizer: Optimizer over ``model``'s parameters.
        device: Target passed to ``Tensor.to``.

    Returns:
        ``(epoch_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of samples whose arg-max prediction along dim 1 equals the
        label. Both are ``float('nan')`` when the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Standard step: clear stale gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Use the same count for the loss weight and the denominator so the
        # epoch average is a true per-sample mean (a short last batch is not
        # over-weighted).
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        predicted = torch.argmax(outputs, dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    if total == 0:
        # Empty loader: the average is undefined; avoid ZeroDivisionError.
        return float("nan"), float("nan")
    return running_loss / total, correct / total
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` over ``val_loader`` without updating it.

    Puts ``model`` in eval mode and disables autograd (``torch.no_grad``)
    while iterating the ``(inputs, labels)`` batches.

    Args:
        model: Network to evaluate; called as ``model(inputs)``.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to
            return the *mean* loss over the batch (e.g. the default
            ``reduction='mean'``) -- TODO confirm for custom losses, because
            the value is re-weighted by batch size below.
        device: Target passed to ``Tensor.to``.

    Returns:
        ``(epoch_loss, accuracy)``: the sample-weighted mean loss and the
        fraction of samples whose arg-max prediction along dim 1 equals the
        label. Both are ``float('nan')`` when the loader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Same count for loss weight and denominator -> per-sample mean.
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            predicted = torch.argmax(outputs, dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Empty loader (e.g. no validation split): metrics are undefined.
        return float("nan"), float("nan")
    return running_loss / total, correct / total
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5, device='cpu',
                save_loss_plot=True, plot_path='training_val_loss_plot.png'):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Moves ``model`` to ``device``, then each epoch calls ``train_epoch`` and
    ``validate_epoch`` and prints the resulting loss/accuracy. Optionally
    writes a train-vs-validation loss curve to ``plot_path``.

    Args:
        model: Network to train (moved to ``device`` in place).
        train_loader: Iterable of training ``(inputs, labels)`` batches.
        val_loader: Iterable of validation ``(inputs, labels)`` batches.
        criterion: Loss function shared by training and validation.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Target passed to ``model.to`` and ``Tensor.to``.
        save_loss_plot: If true, save the per-epoch loss plot.
        plot_path: Destination file for the plot (default keeps the
            historical ``'training_val_loss_plot.png'``).

    Returns:
        The trained ``model``.
    """
    model.to(device)
    train_loss_history = []
    val_loss_history = []
    for epoch in range(num_epochs):
        train_epoch_loss, train_accuracy = train_epoch(model, train_loader, criterion, optimizer, device)
        val_epoch_loss, val_accuracy = validate_epoch(model, val_loader, criterion, device)
        train_loss_history.append(train_epoch_loss)
        val_loss_history.append(val_epoch_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_epoch_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
              f"Val Loss: {val_epoch_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    if save_loss_plot:
        epochs = range(1, len(train_loss_history) + 1)
        fig = plt.figure()
        try:
            plt.plot(epochs, train_loss_history, label='Train Loss', marker='o')
            plt.plot(epochs, val_loss_history, label='Val Loss', marker='o')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training and Validation Loss Over Epochs')
            plt.legend()
            plt.grid(True)
            plt.savefig(plot_path)
        finally:
            # pyplot keeps every figure alive until it is explicitly closed;
            # close it so repeated calls do not accumulate open figures.
            plt.close(fig)
        print(f"Training and validation loss plot saved as '{plot_path}'")
    return model
Leave a Comment