import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Model Setup ---------------------------------------------------------------
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
model.aux_logits = True  # Keep auxiliary classifier

# Unfreeze all layers
for param in model.parameters():
    param.requires_grad = True

# Replace final layers (both main and auxiliary)
num_classes = 2
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, num_classes)
)

if model.AuxLogits is not None:
    num_aux_ftrs = model.AuxLogits.fc.in_features
    model.AuxLogits.fc = nn.Linear(num_aux_ftrs, num_classes)

model = model.to(device)

# 2. Optimizer & Scheduler -----------------------------------------------------
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,            # Lower initial LR for fine-tuning
    weight_decay=1e-2   # Regularization
)

scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',         # Monitor validation accuracy
    factor=0.5,         # Reduce LR by 50%
    patience=2          # Wait 2 epochs w/o improvement
)

# 3. Data Augmentation ---------------------------------------------------------
train_transforms = transforms.Compose([
    transforms.Resize(342),          # For Inception-specific sizing
    transforms.RandomResizedCrop(299),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize(342),
    transforms.CenterCrop(299),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
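
# Dataset & DataLoader setup (sketch) -------------------------------------------
# The training loop below expects train_loader / val_loader and
# train_dataset / val_dataset, which the original snippet never defines.
# This is a minimal sketch assuming an ImageFolder-style layout
# ('data/train' and 'data/val', each with one subfolder per class);
# the paths and batch size are assumptions, adjust them to your data.
from torch.utils.data import DataLoader
from torchvision import datasets

train_dataset = datasets.ImageFolder('data/train', transform=train_transforms)
val_dataset = datasets.ImageFolder('data/val', transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                        num_workers=4, pin_memory=True)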

# 4. Training Loop with Best Practices -----------------------------------------
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()  # For mixed precision training
best_acc = 0.0

for epoch in range(50):  # Maximum 50 epochs
    # Training Phase
    model.train()
    running_loss = 0.0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # Mixed precision forward
        with autocast():
            outputs, aux_outputs = model(inputs)
            loss1 = criterion(outputs, labels)
            loss2 = criterion(aux_outputs, labels)
            loss = loss1 + 0.4 * loss2  # Aux loss weight
            
        # Backward pass with scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item() * inputs.size(0)
    
    # Validation Phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    # Calculate metrics
    train_loss = running_loss / len(train_dataset)
    val_loss = val_loss / len(val_dataset)
    val_acc = 100 * correct / total
    
    # Update scheduler
    scheduler.step(val_acc)
    
    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), 'best_inceptionv3.pth')
    
    print(f'Epoch {epoch+1}/50 | '
          f'Train Loss: {train_loss:.4f} | '
          f'Val Loss: {val_loss:.4f} | '
          f'Val Acc: {val_acc:.2f}% | '
          f'LR: {optimizer.param_groups[0]["lr"]:.2e}')
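
# 5. Inference with the best checkpoint (sketch) ---------------------------------
# Not part of the original snippet: a minimal example of reloading the saved
# weights and classifying a single image. The file name 'example.jpg' is a
# placeholder; replace it with a real image path.
from PIL import Image

model.load_state_dict(torch.load('best_inceptionv3.pth', map_location=device))
model.eval()  # eval mode: the model returns only the main logits

image = val_transforms(Image.open('example.jpg').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    probs = torch.softmax(model(image), dim=1)
print(f'Predicted class: {probs.argmax(1).item()} (p={probs.max().item():.3f})')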