import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import os
import matplotlib.pyplot as plt
from torchvision.models import ResNet18_Weights
from tqdm import tqdm

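# Identical preprocessing for both splits: resize to 416x416, convert to a tensor,
# and normalize with the standard ImageNet channel statistics expected by the
# pretrained ResNet-18 (no augmentation is applied to the training split).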
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize((416, 416)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((416, 416)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
}

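# Dataset root; ImageFolder expects <data_dir>/train/<class>/... and <data_dir>/val/<class>/...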
data_dir = r"C:\Users\Emman\Desktop\corn\dataset\dataset"

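# One dataset and loader per split; class labels are inferred from the subdirectory names.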
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=12, shuffle=True, num_workers=12) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
num_classes = len(class_names)


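# Start from a ResNet-18 pretrained on ImageNet (transfer learning).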
model = models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Freeze the pretrained backbone; the classifier head defined below is created
# fresh, so its parameters keep requires_grad=True and are the only ones trained.
for param in model.parameters():
    param.requires_grad = False

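# Replace the original ImageNet classifier with a small head sized to this dataset.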
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes)
)

# Loss and optimizer (only the parameters of the new classifier head are trainable)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, momentum=0.9)

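# Run on the GPU if one is available, otherwise fall back to the CPU.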
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training function
def train_model(num_epochs=10):
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            class_correct = {class_name: 0 for class_name in class_names}
            class_total = {class_name: 0 for class_name in class_names}

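            # Stream batches for this split with a per-image progress bar.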
            with tqdm(total=dataset_sizes[phase], desc=f"{phase.capitalize()} Progress", unit="img") as pbar:
                for inputs, labels in dataloaders[phase]:
                    inputs, labels = inputs.to(device), labels.to(device)

                    optimizer.zero_grad()

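                    # Forward pass; gradients are tracked only during training.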
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels)

                    # Track per-class accuracy
                    for i in range(len(labels)):
                        label = labels[i]
                        pred = preds[i]
                        class_total[class_names[label]] += 1
                        if label == pred:
                            class_correct[class_names[label]] += 1

                    pbar.update(inputs.size(0))

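            # Average the accumulated loss and accuracy over the whole split.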
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc.item())
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc.item())

            # Print per-class accuracy after each phase
            print(f'Per-Class Accuracy for {phase}:')
            for class_name in class_names:
                acc = class_correct[class_name] / class_total[class_name] * 100 if class_total[class_name] > 0 else 0
                print(f'Class {class_name}: {acc:.2f}%')

    print("Training complete!")
    return train_loss_history, val_loss_history, train_acc_history, val_acc_history

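# Plot loss and accuracy curves for the training and validation splits side by side.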
def plot_loss_accuracy(train_loss, val_loss, train_acc, val_acc):
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.plot(train_loss, label='Train Loss', color='blue', marker='o')
    plt.plot(val_loss, label='Val Loss', color='orange', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(train_acc, label='Train Accuracy', color='green', marker='o')
    plt.plot(val_acc, label='Val Accuracy', color='red', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy over Epochs')
    plt.legend()
    plt.grid()

    plt.tight_layout()
    plt.show()

# Main
if __name__ == "__main__":
    train_loss, val_loss, train_acc, val_acc = train_model(num_epochs=10)

    # Save the model
    torch.save(model.state_dict(), 'corn_leaf_model.pth')
    print("Model saved to 'corn_leaf_model.pth'")

    plot_loss_accuracy(train_loss, val_loss, train_acc, val_acc)
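
    # Minimal sketch (added, not part of the original script): reload the saved
    # weights to confirm the checkpoint is usable. The same architecture must be
    # rebuilt before load_state_dict; "reloaded" is an illustrative name.
    reloaded = models.resnet18(weights=None)
    reloaded.fc = nn.Sequential(
        nn.Linear(reloaded.fc.in_features, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes)
    )
    reloaded.load_state_dict(torch.load('corn_leaf_model.pth', map_location=device))
    reloaded.eval()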