import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import GradScaler, autocast  # torch.cuda.amp is deprecated (requires PyTorch >= 2.3)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Data Preparation ---------------------------------------------------------
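# Inception v3 expects 299x299 inputs (Resize(342) ~= 299 / 0.875 before the
# crop); the Normalize stats below are the standard ImageNet mean/std.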
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(342),
        transforms.RandomResizedCrop(299),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# Load the same folder twice so the train/val subsets can carry different
# transforms. (random_split returns Subsets that share one underlying dataset,
# so reassigning val_dataset.dataset.transform would silently change the
# training pipeline too.)
train_base = datasets.ImageFolder(
    root='path/to/parent_folder',  # Should contain cancer/no_cancer subfolders
    transform=data_transforms['train']
)
val_base = datasets.ImageFolder(
    root='path/to/parent_folder',
    transform=data_transforms['val']
)

# Split indices (80% train, 20% val), applying the same split to both views
train_size = int(0.8 * len(train_base))
indices = torch.randperm(len(train_base)).tolist()
train_dataset = Subset(train_base, indices[:train_size])
val_dataset = Subset(val_base, indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False,
                        num_workers=4, pin_memory=True)

# 2. Model Setup --------------------------------------------------------------
# Pretrained Inception v3 is built with aux_logits=True, so the auxiliary
# classifier is available in train mode (weights= API needs torchvision >= 0.13)
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)

# Unfreeze all layers
for param in model.parameters():
    param.requires_grad = True

# Modify final layers
num_classes = 2
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, num_classes)
)

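# The auxiliary head is only active in train mode; give it a matching
# 2-class output as well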
if model.AuxLogits is not None:
    model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)

model = model.to(device)

# 3. Training Configuration ---------------------------------------------------
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
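# mode='max' because the scheduler steps on validation F1 (higher is better)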
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler('cuda')

# 4. Enhanced Metrics Calculation ---------------------------------------------
def calculate_metrics(outputs, labels):
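    # NOTE: treats label 1 as the positive class. ImageFolder assigns indices
    # alphabetically, so with cancer/no_cancer folders 'cancer' maps to 0;
    # check train_base.class_to_idx and flip the convention if needed.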
    _, preds = torch.max(outputs, 1)
    
    # Confusion matrix
    tn = ((preds == 0) & (labels == 0)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    tp = ((preds == 1) & (labels == 1)).sum().item()
    
    # Calculate metrics
    accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-9)
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
    balanced_acc = (recall + (tn / (tn + fp + 1e-9))) / 2
    
    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'tn': tn,
        'fn': fn
    }

# 5. Training Loop ------------------------------------------------------------
best_f1 = 0.0

for epoch in range(50):
    # Training
    model.train()
    train_loss = 0.0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        with autocast('cuda'):
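            # In train mode Inception v3 returns (main logits, aux logits);
            # the auxiliary loss is conventionally down-weighted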
            outputs, aux_outputs = model(inputs)
            loss1 = criterion(outputs, labels)
            loss2 = criterion(aux_outputs, labels)
            loss = loss1 + 0.4 * loss2
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        train_loss += loss.item() * inputs.size(0)
    
    # Validation
    model.eval()
    val_loss = 0.0
    all_outputs = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
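            # In eval mode the model returns only the main logits (no aux head)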
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * inputs.size(0)
            all_outputs.append(outputs)
            all_labels.append(labels)
    
    # Aggregate validation results
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    metrics = calculate_metrics(all_outputs, all_labels)
    
    # Update scheduler
    scheduler.step(metrics['f1'])
    
    # Save best model
    if metrics['f1'] > best_f1:
        best_f1 = metrics['f1']
        torch.save(model.state_dict(), 'best_model.pth')
    
    # Print results
    print(f"\nEpoch {epoch+1}/50")
    print(f"Train Loss: {train_loss/len(train_dataset):.4f}")
    print(f"Val Loss: {val_loss/len(val_dataset):.4f}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Balanced Acc: {metrics['balanced_accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-Score: {metrics['f1']:.4f}")
    print(f"Confusion Matrix:")
    print(f"[[{metrics['tn']} {metrics['fp']}]")
    print(f" [{metrics['fn']} {metrics['tp']}]]")
    print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")