Untitled

 avatar
unknown
python
5 months ago
3.0 kB
4
Indexable
import torch
import torchvision
from torchvision import models, transforms, datasets

# ImageNet per-channel [means, stds] expected by torchvision's pretrained weights.
imagenet_stats = [(0.485, 0.456, 0.406), (0.229, 0.224, 0.225)]

# Evaluation-style preprocessing. Food-101 images can be up to 512 px per side,
# so a bare CenterCrop(224) would keep only a small central patch at a scale the
# pretrained network never saw. Resize the shorter side first (232 matches the
# ResNet-50 IMAGENET1K_V2 inference recipe), then center-crop to 224.
valid_tfms = transforms.Compose([
    transforms.Resize(232),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_stats[0], imagenet_stats[1]),
])

batch_size = 196
trainset = torchvision.datasets.Food101(root='./food101', split="train",
                                        download=False, transform=valid_tfms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=6)

validset = torchvision.datasets.Food101(root='./food101', split="test",
                                       download=False, transform=valid_tfms)
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size,
                                         shuffle=False, num_workers=6)

# Integer labels are only comparable across splits if both share one class list.
# Raise explicitly: a bare `assert` is removed when Python runs with -O.
if trainset.classes != validset.classes:
    raise ValueError("train and test splits have different class lists")
classes = trainset.classes
classe2idx = trainset.class_to_idx  # (sic) name kept for compatibility with other code
num_classes = len(classes)

print("Number of classes =", num_classes)


import torch
import torchvision
from torchvision import models, transforms, datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
import numpy as np
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50 = models.resnet50(weights="IMAGENET1K_V2")
resnet50.to(device)
resnet50.eval()  # inference mode: BatchNorm uses its running statistics

# Randomly select 20000 training images. A fixed-seed generator makes the
# subset, and therefore every downstream metric, reproducible across runs.
num_train_samples = 20000
subset_generator = torch.Generator().manual_seed(0)
train_indices = torch.randperm(len(trainset), generator=subset_generator)[:num_train_samples]
train_subset = torch.utils.data.Subset(trainset, train_indices)

# These loaders are used only for feature extraction. The subset is already a
# random sample, and images/labels stay paired per batch, so shuffling adds
# nothing. A fixed order keeps the extracted feature matrices reproducible.
trainloader = torch.utils.data.DataLoader(train_subset, batch_size=32, shuffle=False, num_workers=4)
validloader = torch.utils.data.DataLoader(validset, batch_size=32, shuffle=False, num_workers=4)


def extract_features(dataloader, model, device=None):
    """Run *model* over every batch of *dataloader* and collect its outputs.

    Args:
        dataloader: Iterable yielding ``(images, targets)`` tensor batches,
            e.g. a ``torch.utils.data.DataLoader``.
        model: Module applied to each image batch. Its raw outputs are used
            as features; for the ImageNet ResNet-50 in this script that is
            the final 1000-dimensional logits. This function does not change
            the model's train/eval mode; the caller sets it.
        device: Device that image batches are moved to. Defaults to the
            device of the model's first parameter (CPU if it has none), so
            inputs always land where the weights are.

    Returns:
        ``(features, labels)``: the per-batch outputs and targets, each
        concatenated along axis 0 in iteration order.

    Raises:
        ValueError: if *dataloader* yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    features, labels = [], []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for images, targets in tqdm(dataloader, desc="Extracting Features"):
            outputs = model(images.to(device))
            features.append(outputs.cpu().numpy())
            labels.append(targets.numpy())

    if not features:
        raise ValueError("dataloader yielded no batches; nothing to extract")
    # concatenate(axis=0) stacks batches for any output/label rank, whereas
    # vstack/hstack would mis-stack 1-D outputs or 2-D labels.
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)
    
    
from sklearn.preprocessing import StandardScaler

train_features, train_labels = extract_features(trainloader, resnet50)
valid_features, valid_labels = extract_features(validloader, resnet50)

# The saga solver converges quickly only when features have a similar scale,
# and the raw model outputs are not normalized anywhere above. Standardize
# using statistics from the training features only, then apply the same
# transform to the validation features, so no information leaks from them.
scaler = StandardScaler()
train_features_scaled = scaler.fit_transform(train_features)
valid_features_scaled = scaler.transform(valid_features)

print("Training Logistic Regression...")
log_reg = LogisticRegression(max_iter=500, random_state=0, solver="saga", verbose=1)
log_reg.fit(train_features_scaled, train_labels)

print("Evaluating Logistic Regression...")
valid_preds = log_reg.predict(valid_features_scaled)
accuracy = accuracy_score(valid_labels, valid_preds)
macro_f1 = f1_score(valid_labels, valid_preds, average="macro")

print(f"Validation Accuracy: {accuracy:.4f}")
print(f"Macro-average F1 Score: {macro_f1:.4f}")
Editor is loading...
Leave a Comment