import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1. Data Preparation ---------------------------------------------------------
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(342),
        transforms.RandomResizedCrop(299),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
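# Inception v3 expects 299x299 inputs; Resize(342) before the 299 crop keeps
# roughly the standard 0.875 crop ratio. Normalization uses ImageNet statistics.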
# Load the dataset twice so the train and val splits use different transforms.
# (random_split returns Subsets that share one underlying dataset, so mutating
# val_dataset.dataset.transform would silently change the training transform too.)
train_base = datasets.ImageFolder(
    root='path/to/parent_folder',  # Should contain cancer/no_cancer subfolders
    transform=data_transforms['train']
)
val_base = datasets.ImageFolder(
    root='path/to/parent_folder',
    transform=data_transforms['val']
)
# ImageFolder assigns labels alphabetically; check train_base.class_to_idx to
# confirm which class the metrics below treat as positive (class 1).

# Split dataset (80% train, 20% val) using the same shuffled indices for both copies
indices = torch.randperm(len(train_base), generator=torch.Generator().manual_seed(42)).tolist()
train_size = int(0.8 * len(train_base))
train_dataset = Subset(train_base, indices[:train_size])
val_dataset = Subset(val_base, indices[train_size:])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
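# Note: with num_workers > 0, DataLoader spawns worker processes; on platforms
# using the "spawn" start method (e.g. Windows), this script body should run
# under an `if __name__ == "__main__":` guard.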
# 2. Model Setup --------------------------------------------------------------
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
# aux_logits is True by default for the pretrained model, so the auxiliary
# classifier is trained alongside the main head below.

# Unfreeze all layers for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Replace the final classifier heads for binary classification
num_classes = 2
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, num_classes)
)
if model.AuxLogits is not None:
    model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)
model = model.to(device)
# 3. Training Configuration ---------------------------------------------------
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()
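# GradScaler scales the loss so float16 gradients produced under autocast do
# not underflow; scaler.step() unscales them before the optimizer update.
# ReduceLROnPlateau uses mode='max' because it is stepped on validation F1.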
# 4. Enhanced Metrics Calculation ---------------------------------------------
def calculate_metrics(outputs, labels):
    _, preds = torch.max(outputs, 1)
    # Confusion matrix (class 1 is the positive class)
    tn = ((preds == 0) & (labels == 0)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    tp = ((preds == 1) & (labels == 1)).sum().item()
    # Calculate metrics (the 1e-9 terms guard against division by zero)
    accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-9)
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-9)
    balanced_acc = (recall + (tn / (tn + fp + 1e-9))) / 2
    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'tn': tn,
        'fn': fn
    }
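# Quick sanity check for calculate_metrics (toy values, not pipeline data):
#   logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [1.0, 3.0], [4.0, 0.5]])
#   labels = torch.tensor([0, 1, 1, 0])
#   calculate_metrics(logits, labels)
#   # predictions are [0, 1, 1, 0], so tp=2, tn=2, fp=0, fn=0 and f1 ~ 1.0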
# 5. Training Loop ------------------------------------------------------------
best_f1 = 0.0
for epoch in range(50):
    # Training
    model.train()
    train_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():
            # In train mode, Inception v3 returns (main logits, aux logits)
            outputs, aux_outputs = model(inputs)
            loss1 = criterion(outputs, labels)
            loss2 = criterion(aux_outputs, labels)
            loss = loss1 + 0.4 * loss2  # down-weight the auxiliary loss
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item() * inputs.size(0)

    # Validation
    model.eval()
    val_loss = 0.0
    all_outputs = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)  # eval mode returns the main logits only
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            all_outputs.append(outputs)
            all_labels.append(labels)

    # Aggregate validation results
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    metrics = calculate_metrics(all_outputs, all_labels)

    # Update scheduler on validation F1
    scheduler.step(metrics['f1'])

    # Save best model
    if metrics['f1'] > best_f1:
        best_f1 = metrics['f1']
        torch.save(model.state_dict(), 'best_model.pth')

    # Print results
    print(f"\nEpoch {epoch+1}/50")
    print(f"Train Loss: {train_loss/len(train_dataset):.4f}")
    print(f"Val Loss: {val_loss/len(val_dataset):.4f}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Balanced Acc: {metrics['balanced_accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-Score: {metrics['f1']:.4f}")
    print("Confusion Matrix:")
    print(f"[[{metrics['tn']} {metrics['fp']}]")
    print(f" [{metrics['fn']} {metrics['tp']}]]")
    print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")