Untitled

 avatar
unknown
plain_text
a month ago
5.7 kB
5
Indexable
import os
import torch
import torch.nn as nn
from torch.optim import Adam
from dataset import load_tiny_imagenet
from tqdm import tqdm

class ResidualBlock(nn.Module):
    """Two-convolution residual block with an optional projection shortcut.

    Residual branch: 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv -> BatchNorm.
    The shortcut is the input itself unless the block changes the channel count
    or uses a stride other than 1. In that case a 1x1 strided convolution
    followed by BatchNorm projects the input to the branch's output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        """Build the block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride of the first 3x3 convolution (and of the 1x1
                projection shortcut, when one is created).
        """
        super().__init__()
        # Attribute names are the state_dict keys and creation order fixes the
        # parameter order, so both are kept stable for checkpoint compatibility.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The branch changes the tensor shape when the stride or channel count
        # changes; the shortcut must then be projected so the two can be added.
        needs_projection = stride != 1 or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ReLU(branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # Merge the shortcut before the final activation.
        branch += shortcut
        return self.relu(branch)

class Encoder(nn.Module):
    """Convolutional feature extractor: a 3x3 stem followed by four residual stages.

    Channel widths go 3 -> 64 -> 128 -> 256 -> 512 -> 1024. Each residual stage
    uses stride 2, so spatial size is halved four times. For 64x64 inputs
    (presumably Tiny ImageNet images, per the entry point) the output map is
    4x4.
    """

    def __init__(self):
        super().__init__()

        # Stem keeps the spatial resolution (stride 1, padding 1) and lifts 3 -> 64 channels.
        stem = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]

        # Consecutive width pairs: (64, 128), (128, 256), (256, 512), (512, 1024).
        widths = (64, 128, 256, 512, 1024)
        stages = [
            ResidualBlock(c_in, c_out, stride=2)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]

        # A single Sequential named `encoder` keeps the original state_dict keys.
        self.encoder = nn.Sequential(*stem, *stages)

    def forward(self, x):
        """Return the feature map produced by the stem and residual stages."""
        return self.encoder(x)

# Classifier definition
class Classifier(nn.Module):
    """Classification head: global average pooling followed by a two-layer MLP.

    Args:
        num_classes (int): Number of output logits.
        in_features (int): Channel count of the incoming feature map. Defaults to
            1024, the width of ``Encoder``'s final stage.
        hidden_features (int): Width of the hidden fully connected layer.
    """

    def __init__(self, num_classes=200, in_features=1024, hidden_features=512):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Map a feature map to class logits.

        Args:
            x (torch.Tensor): Feature map; assumed (batch, in_features, H, W).
                It is averaged over dims 2 and 3.

        Returns:
            torch.Tensor: Logits of shape (batch, num_classes).
        """
        x = x.mean((2, 3))  # global average pooling over the spatial dims
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Top-N accuracy computation
def eval_top_n_accuracy(predictions, labels, n=1):
    """Return the fraction of samples whose true label is among the top-``n`` scores.

    Args:
        predictions (torch.Tensor): Class scores, assumed (batch, num_classes);
            ``topk`` is taken along dim 1, so higher scores rank first.
        labels (torch.Tensor): Integer class indices, one per sample.
        n (int): How many of the highest-scoring classes count as a hit. Must
            not exceed the number of classes (``torch.topk`` raises otherwise).

    Returns:
        float: Accuracy in [0, 1]. An empty batch yields 0.0 instead of raising
        ``ZeroDivisionError``.
    """
    num_samples = labels.size(0)
    if num_samples == 0:
        return 0.0
    _, top_n_preds = predictions.topk(n, dim=1, largest=True, sorted=True)
    # (batch, 1) broadcasts against (batch, n); topk indices are distinct, so
    # each row contributes at most one hit.
    correct = top_n_preds.eq(labels.reshape(-1, 1))
    return correct.sum().item() / num_samples

# Model training
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0 (empty loader)."""
    return numerator / denominator if denominator else float("nan")


def _train_one_epoch(encoder, classifier, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        tuple[float, float]: (mean per-sample training loss, top-1 training
        accuracy). Both are NaN if the loader yields no samples.
    """
    encoder.train()
    classifier.train()

    running_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = classifier(encoder(images))
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # CrossEntropyLoss() uses its default 'mean' reduction, so re-weight by
        # batch size to get an exact per-sample mean across uneven batches.
        running_loss += loss.item() * batch_size
        _, preds = outputs.max(1)
        total_correct += preds.eq(labels).sum().item()
        total_samples += batch_size

    return _safe_ratio(running_loss, total_samples), _safe_ratio(total_correct, total_samples)


def _evaluate(encoder, classifier, test_loader, device, top_n):
    """Compute top-1 and top-``top_n`` accuracy over ``test_loader`` without gradients.

    Returns:
        tuple[float, float]: (top-1 accuracy, top-``top_n`` accuracy). Both are
        NaN if the loader yields no samples.
    """
    encoder.eval()
    classifier.eval()

    total_top1_acc = 0.0
    total_topn_acc = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = classifier(encoder(images))
            batch_size = labels.size(0)
            # Per-batch accuracies are weighted by batch size so the final value
            # is a per-sample average even if the last batch is smaller.
            total_top1_acc += eval_top_n_accuracy(outputs, labels, n=1) * batch_size
            total_topn_acc += eval_top_n_accuracy(outputs, labels, n=top_n) * batch_size
            total_samples += batch_size

    return _safe_ratio(total_top1_acc, total_samples), _safe_ratio(total_topn_acc, total_samples)


def train_model(encoder, classifier, train_loader, test_loader, device, epochs=10, lr=1e-3, top_n=5):
    """Jointly train ``encoder`` + ``classifier`` with Adam and cross-entropy loss.

    After every epoch the models are evaluated on ``test_loader``, and one
    summary line is printed.

    Args:
        encoder (nn.Module): Feature extractor; its output is fed to ``classifier``.
        classifier (nn.Module): Maps encoder features to class logits.
        train_loader: Iterable of ``(images, labels)`` batches used for training.
        test_loader: Iterable of ``(images, labels)`` batches used for evaluation.
        device (torch.device): Device the models and batches are moved to.
        epochs (int): Number of passes over ``train_loader``.
        lr (float): Adam learning rate.
        top_n (int): ``n`` for the reported top-n test accuracy.

    Returns:
        list[dict]: One entry per epoch with keys ``epoch`` (1-based),
        ``train_loss``, ``train_acc``, ``top1_acc`` and ``topn_acc``. Metrics
        from an empty loader are NaN instead of raising ``ZeroDivisionError``.
    """
    encoder.to(device)
    classifier.to(device)

    # Loss function and a single optimizer over both modules' parameters.
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=lr)

    history = []
    for epoch in range(epochs):
        train_loss, train_acc = _train_one_epoch(
            encoder, classifier, train_loader, criterion, optimizer, device
        )
        top1_acc, topn_acc = _evaluate(encoder, classifier, test_loader, device, top_n)

        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Top-1 Acc: {top1_acc:.4f}, Top-{top_n} Acc: {topn_acc:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "top1_acc": top1_acc,
            "topn_acc": topn_acc,
        })

    return history

if __name__ == "__main__":
    # Root directory of the Tiny ImageNet dataset.
    data_dir = "./datasets/tiny-imagenet-200/"

    # Run configuration.
    batch_size, epochs, learning_rate, top_n = 256, 20, 1e-4, 5
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Data loaders are created before the models, as in the original run order.
    train_loader, test_loader = load_tiny_imagenet(data_dir, batch_size)

    # Models: encoder first, then the 200-way classification head.
    encoder, classifier = Encoder(), Classifier(num_classes=200)

    train_model(
        encoder,
        classifier,
        train_loader,
        test_loader,
        device,
        epochs=epochs,
        lr=learning_rate,
        top_n=top_n,
    )
Leave a Comment