import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1. Model Setup ---------------------------------------------------------------
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)  # `pretrained=True` is deprecated
model.aux_logits = True # Keep auxiliary classifier
# Unfreeze all layers
for param in model.parameters():
    param.requires_grad = True
# Replace final layers (both main and auxiliary)
num_classes = 2
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, num_classes)
)
if model.AuxLogits is not None:
    num_aux_ftrs = model.AuxLogits.fc.in_features
    model.AuxLogits.fc = nn.Linear(num_aux_ftrs, num_classes)
model = model.to(device)
# 2. Optimizer & Scheduler -----------------------------------------------------
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,           # Lower initial LR for fine-tuning
    weight_decay=1e-2  # Regularization
)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',   # Monitor validation accuracy
    factor=0.5,   # Halve the LR on plateau
    patience=2    # Wait 2 epochs w/o improvement
)  # `verbose` is deprecated in recent PyTorch; the LR is printed each epoch below
# 3. Data Augmentation ---------------------------------------------------------
train_transforms = transforms.Compose([
    transforms.Resize(342),             # Inception-specific sizing
    transforms.RandomResizedCrop(299),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
    transforms.Resize(342),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 4. Training Loop with Best Practices -----------------------------------------
criterion = nn.CrossEntropyLoss()
scaler = GradScaler() # For mixed precision training
best_acc = 0.0
for epoch in range(50):  # Maximum 50 epochs
    # Training Phase
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Mixed precision forward (train mode returns main + auxiliary logits)
        with autocast():
            outputs, aux_outputs = model(inputs)
            loss1 = criterion(outputs, labels)
            loss2 = criterion(aux_outputs, labels)
            loss = loss1 + 0.4 * loss2  # Aux loss weight
        # Backward pass with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * inputs.size(0)
    # Validation Phase (eval mode returns only the main output)
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    # Calculate epoch metrics
    train_loss = running_loss / len(train_dataset)
    val_loss = val_loss / len(val_dataset)
    val_acc = 100 * correct / total
    # Step the scheduler on validation accuracy
    scheduler.step(val_acc)
    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_inceptionv3.pth')
    print(f'Epoch {epoch+1}/50 | '
          f'Train Loss: {train_loss:.4f} | '
          f'Val Loss: {val_loss:.4f} | '
          f'Val Acc: {val_acc:.2f}% | '
          f'LR: {optimizer.param_groups[0]["lr"]:.2e}')
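# Hypothetical post-training usage (not in the original paste) ------------------
# A minimal sketch of reloading the saved best checkpoint for inference; the
# file name matches the torch.save() call above. Inception expects 299x299 input.
model.load_state_dict(torch.load('best_inceptionv3.pth', map_location=device))
model.eval()
with torch.no_grad():
    example = torch.randn(1, 3, 299, 299, device=device)  # stand-in for a real batch
    pred = model(example).argmax(dim=1)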