Untitled
unknown
plain_text
a year ago
5.7 kB
10
Indexable
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torchvision.models import ResNet18_Weights
from tqdm import tqdm

# Fine-tunes the classification head of a pretrained ResNet-18 on an
# ImageFolder dataset laid out as <data_dir>/train/<class>/... and
# <data_dir>/val/<class>/..., saves the weights, and plots loss/accuracy.
#
# All expensive work (directory scanning, weight loading, moving the model to
# the GPU) happens inside functions called from the ``__main__`` guard.
# DataLoader workers started with the "spawn" method (the default on Windows)
# re-import this module, so anything done at import time would be repeated in
# every worker each time a loader is iterated.

PHASES = ("train", "val")
IMAGE_SIZE = (416, 416)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
MODEL_PATH = 'corn_leaf_model.pth'

# Both splits use the same deterministic pipeline (no augmentation).
data_transforms = {
    phase: transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    for phase in PHASES
}

data_dir = r"C:\Users\Emman\Desktop\corn\dataset\dataset"


@dataclass
class TrainingSetup:
    """Everything ``train_model`` needs to run: data, model, loss, optimizer."""

    model: nn.Module
    dataloaders: dict  # phase name -> DataLoader
    dataset_sizes: dict  # phase name -> number of images
    class_names: list  # index -> class folder name (shared by all phases)
    criterion: nn.Module
    optimizer: optim.Optimizer
    device: torch.device


def build_setup(root=None, batch_size=12, num_workers=12):
    """Create datasets, loaders, the model, loss and optimizer.

    Args:
        root: Dataset root containing ``train`` and ``val`` sub-folders.
            Defaults to the module-level ``data_dir``.
        batch_size: Images per batch for both loaders.
        num_workers: DataLoader worker processes per loader.

    Returns:
        A ``TrainingSetup`` with the model already moved to the device.

    Raises:
        ValueError: if the ``val`` split's class folders differ from
            ``train``'s. ImageFolder assigns label indices per split, so a
            mismatch would silently misalign labels.
    """
    root = data_dir if root is None else root
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    image_datasets = {
        phase: datasets.ImageFolder(os.path.join(root, phase), data_transforms[phase])
        for phase in PHASES
    }
    class_names = image_datasets['train'].classes
    if image_datasets['val'].classes != class_names:
        raise ValueError(
            f"train/val class folders differ: {class_names} vs "
            f"{image_datasets['val'].classes}"
        )

    dataloaders = {
        phase: torch.utils.data.DataLoader(
            image_datasets[phase],
            batch_size=batch_size,
            shuffle=(phase == 'train'),  # evaluation order does not matter
            num_workers=num_workers,
            pin_memory=(device.type == "cuda"),
        )
        for phase in PHASES
    }
    dataset_sizes = {phase: len(image_datasets[phase]) for phase in PHASES}

    model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    # Freeze the pretrained backbone; the replacement head created below gets
    # fresh parameters with requires_grad=True, so only it is trained.
    for param in model.parameters():
        param.requires_grad = False
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.ReLU(),
        nn.Linear(512, len(class_names)),
    )
    model = model.to(device)

    trainable = [p for p in model.parameters() if p.requires_grad]
    return TrainingSetup(
        model=model,
        dataloaders=dataloaders,
        dataset_sizes=dataset_sizes,
        class_names=class_names,
        criterion=nn.CrossEntropyLoss(),
        optimizer=optim.SGD(trainable, lr=0.001, momentum=0.9),
        device=device,
    )


def _run_phase(setup, phase):
    """Run one pass over ``setup.dataloaders[phase]``, training if phase is 'train'.

    Returns:
        (mean loss, accuracy, per-class correct counts, per-class totals);
        the last two are lists indexed like ``setup.class_names``.
    """
    is_train = phase == 'train'
    model, device, optimizer = setup.model, setup.device, setup.optimizer
    num_classes = len(setup.class_names)
    model.train(is_train)  # train(False) is the same as eval()

    running_loss = 0.0
    running_corrects = 0
    class_correct = torch.zeros(num_classes, dtype=torch.long, device=device)
    class_total = torch.zeros(num_classes, dtype=torch.long, device=device)

    with tqdm(total=setup.dataset_sizes[phase], desc=f"{phase.capitalize()} Progress", unit="img") as pbar:
        for inputs, labels in setup.dataloaders[phase]:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            if is_train:
                optimizer.zero_grad()
            with torch.set_grad_enabled(is_train):
                outputs = model(inputs)
                loss = setup.criterion(outputs, labels)
                if is_train:
                    loss.backward()
                    optimizer.step()

            preds = outputs.argmax(dim=1)
            hits = preds == labels
            n = inputs.size(0)
            # criterion returns the batch mean; re-weight by batch size so the
            # epoch loss is a per-image average even with a short last batch.
            running_loss += loss.item() * n
            running_corrects += hits.sum().item()
            # Vectorised per-class counts (no per-sample Python loop or syncs).
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[hits], minlength=num_classes)
            pbar.update(n)

    size = setup.dataset_sizes[phase]
    return (
        running_loss / size,
        running_corrects / size,
        class_correct.tolist(),
        class_total.tolist(),
    )


def train_model(num_epochs=10, *, setup=None):
    """Train for ``num_epochs`` epochs, evaluating on 'val' after each one.

    Prints per-phase loss/accuracy and per-class accuracy every epoch.

    Args:
        num_epochs: Number of train+val epochs.
        setup: A ``TrainingSetup`` from ``build_setup()``. If omitted, a
            default one is built; pass your own to keep access to the trained
            model afterwards.

    Returns:
        (train_loss_history, val_loss_history, train_acc_history,
        val_acc_history) as lists of floats, one entry per epoch.
    """
    if setup is None:
        setup = build_setup()

    history = {phase: {"loss": [], "acc": []} for phase in PHASES}
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        for phase in PHASES:
            epoch_loss, epoch_acc, class_correct, class_total = _run_phase(setup, phase)
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            history[phase]["loss"].append(epoch_loss)
            history[phase]["acc"].append(epoch_acc)

            print(f'Per-Class Accuracy for {phase}:')
            for class_name, correct, total in zip(setup.class_names, class_correct, class_total):
                acc = correct / total * 100 if total > 0 else 0
                print(f'Class {class_name}: {acc:.2f}%')

    print("Training complete!")
    return (
        history['train']["loss"],
        history['val']["loss"],
        history['train']["acc"],
        history['val']["acc"],
    )


def plot_loss_accuracy(train_loss, val_loss, train_acc, val_acc):
    """Show side-by-side loss and accuracy curves, with epochs numbered from 1."""
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.plot(range(1, len(train_loss) + 1), train_loss, label='Train Loss', color='blue', marker='o')
    plt.plot(range(1, len(val_loss) + 1), val_loss, label='Val Loss', color='orange', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(train_acc) + 1), train_acc, label='Train Accuracy', color='green', marker='o')
    plt.plot(range(1, len(val_acc) + 1), val_acc, label='Val Accuracy', color='red', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy over Epochs')
    plt.legend()
    plt.grid()

    plt.tight_layout()
    plt.show()


def main():
    """Build the setup, train, save the weights, and plot the curves."""
    setup = build_setup()
    train_loss, val_loss, train_acc, val_acc = train_model(num_epochs=10, setup=setup)

    # Save the model
    torch.save(setup.model.state_dict(), MODEL_PATH)
    print(f"Model saved to '{MODEL_PATH}'")

    plot_loss_accuracy(train_loss, val_loss, train_acc, val_acc)


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment