# Inception-v3 binary classifier (cancer / no_cancer) — training script
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Subset, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast
import numpy as np

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Data Preparation ---------------------------------------------------------
# Inception-v3 expects 299x299 inputs; both pipelines normalize with ImageNet
# statistics because the backbone is ImageNet-pretrained.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(342),
        transforms.RandomResizedCrop(299),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2, 0.2, 0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(342),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# Two ImageFolder views of the same directory, one per transform pipeline.
# BUG FIX: the original code split a single dataset with random_split and then
# assigned `val_dataset.dataset.transform = data_transforms['val']`. That
# mutates the transform on the SHARED underlying ImageFolder, so both the
# train and val subsets ended up using the validation pipeline — silently
# disabling every training augmentation.
train_base = datasets.ImageFolder(
    root='path/to/parent_folder',  # Should contain cancer/no_cancer subfolders
    transform=data_transforms['train']
)
val_base = datasets.ImageFolder(
    root='path/to/parent_folder',
    transform=data_transforms['val']
)

# Split indices once (80% train, 20% val) with a fixed seed so the two Subset
# views are disjoint and the split is reproducible across runs.
train_size = int(0.8 * len(train_base))
perm = torch.randperm(
    len(train_base), generator=torch.Generator().manual_seed(42)
).tolist()
train_dataset = Subset(train_base, perm[:train_size])
val_dataset = Subset(val_base, perm[train_size:])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
# 2. Model Setup --------------------------------------------------------------
# ImageNet-pretrained Inception-v3. The pretrained model ships with its
# auxiliary classifier enabled; the assignment keeps that explicit.
model = models.inception_v3(pretrained=True)
model.aux_logits = True

# Fine-tune end to end: no layers are frozen.
for param in model.parameters():
    param.requires_grad = True

# Replace both classification heads for the 2-class problem.
num_classes = 2
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, num_classes)
)
if model.AuxLogits is not None:
    model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)
model = model.to(device)

# 3. Training Configuration ---------------------------------------------------
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
# mode='max' because the scheduler is stepped with validation F1 (higher is better).
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

# 4. Enhanced Metrics Calculation ---------------------------------------------
def calculate_metrics(outputs, labels):
    """Compute binary-classification metrics from raw model outputs.

    Class 1 is treated as the positive class; predictions are the argmax
    over the class dimension.

    Args:
        outputs: (N, 2) tensor of class scores/logits.
        labels: (N,) tensor of ground-truth labels in {0, 1}.

    Returns:
        dict with accuracy, balanced_accuracy, precision, recall, f1 and
        the raw confusion-matrix counts tp/fp/tn/fn (Python ints).
    """
    _, preds = torch.max(outputs, 1)

    # Confusion-matrix counts.
    tn = ((preds == 0) & (labels == 0)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    tp = ((preds == 1) & (labels == 1)).sum().item()

    # FIX: guard each ratio explicitly instead of adding 1e-9 to every
    # denominator. The epsilon variant never produced exact values (a
    # perfect run reported accuracy 0.999999999 instead of 1.0) and mixed
    # numerical smoothing into what is integer arithmetic.
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    balanced_acc = (recall + specificity) / 2

    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn
    }
# 5. Training Loop ------------------------------------------------------------
# Trains for 50 epochs under mixed precision; the auxiliary head contributes
# a 0.4-weighted loss term (standard Inception-v3 recipe). The checkpoint
# with the best validation F1 seen so far is written to disk.
best_f1 = 0.0
for epoch in range(50):

    # ---- training phase ----
    model.train()
    running_loss = 0.0
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        with autocast():
            main_out, aux_out = model(images)
            main_loss = criterion(main_out, targets)
            aux_loss = criterion(aux_out, targets)
            batch_loss = main_loss + 0.4 * aux_loss
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += batch_loss.item() * images.size(0)

    # ---- validation phase ----
    model.eval()
    val_running_loss = 0.0
    collected_logits = []
    collected_targets = []
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device)
            logits = model(images)  # eval mode: single output, no aux head
            val_running_loss += criterion(logits, targets).item() * images.size(0)
            collected_logits.append(logits)
            collected_targets.append(targets)

    # Metrics over the whole validation set at once.
    metrics = calculate_metrics(torch.cat(collected_logits),
                                torch.cat(collected_targets))

    # LR schedule follows validation F1 (scheduler was built with mode='max').
    scheduler.step(metrics['f1'])

    # Checkpoint on improvement.
    if metrics['f1'] > best_f1:
        best_f1 = metrics['f1']
        torch.save(model.state_dict(), 'best_model.pth')

    # ---- epoch report ----
    print(f"\nEpoch {epoch+1}/50")
    print(f"Train Loss: {running_loss/len(train_dataset):.4f}")
    print(f"Val Loss: {val_running_loss/len(val_dataset):.4f}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Balanced Acc: {metrics['balanced_accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-Score: {metrics['f1']:.4f}")
    print(f"Confusion Matrix:")
    print(f"[[{metrics['tn']} {metrics['fp']}]")
    print(f" [{metrics['fn']} {metrics['tp']}]]")
    print(f"LR: {optimizer.param_groups[0]['lr']:.2e}")
# End of training script.