Untitled
"""Train a small ResNet-style encoder and classifier on Tiny ImageNet."""

import os

import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm

from dataset import load_tiny_imagenet


class ResidualBlock(nn.Module):
    """Two 3x3 conv/batch-norm layers with an additive skip connection."""

    def __init__(self, in_channels, out_channels, stride=1):
        """Initialize the residual block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride of the first convolution (and of the
                projection on the skip path, when one is needed).
        """
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # When the channel count or the spatial size changes, the identity
        # path must be projected (1x1 conv) so it can be added to the output.
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + skip(x))``."""
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)
        return out


class Encoder(nn.Module):
    """Convolutional stem followed by four stride-2 residual blocks."""

    def __init__(self):
        super().__init__()
        # The stem keeps the spatial size; each residual block halves it, so
        # the output is 1/16 of the input per side (4x4 for 64x64 inputs).
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 1024, stride=2),
        )

    def forward(self, x):
        """Encode a batch of 3-channel images into 1024-channel feature maps."""
        return self.encoder(x)


class Classifier(nn.Module):
    """Global-average-pool head with one hidden layer."""

    def __init__(self, num_classes=200):
        """Build the head.

        Args:
            num_classes (int): Number of output logits.
        """
        super().__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map ``(N, 1024, H, W)`` feature maps to ``(N, num_classes)`` logits."""
        x = x.mean(dim=(2, 3))  # global average pooling over H and W
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def eval_top_n_accuracy(predictions, labels, n=1):
    """Return the fraction of samples whose label is among the top-n scores.

    Args:
        predictions (torch.Tensor): Scores of shape ``(batch, num_classes)``.
        labels (torch.Tensor): Integer class indices of shape ``(batch,)``.
        n (int): How many of the highest-scoring classes count as a hit.
            Values larger than ``num_classes`` are clamped to it.

    Returns:
        float: Accuracy in ``[0, 1]``; ``0.0`` for an empty batch.
    """
    batch_size = labels.size(0)
    if batch_size == 0:
        # Avoid a ZeroDivisionError on an empty batch.
        return 0.0

    # topk raises if k exceeds the size of the class dimension.
    k = min(n, predictions.size(1))
    _, top_n_preds = predictions.topk(k, dim=1, largest=True, sorted=True)
    correct = top_n_preds.eq(labels.view(-1, 1).expand_as(top_n_preds))
    return correct.sum().item() / batch_size


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def train_model(encoder, classifier, train_loader, test_loader, device,
                epochs=10, lr=1e-3, top_n=5):
    """Jointly train the encoder and classifier, evaluating after every epoch.

    Per epoch, prints the sample-weighted training loss and accuracy, plus
    top-1 and top-n accuracy measured on ``test_loader``.

    Args:
        encoder (nn.Module): Feature extractor applied to the images.
        classifier (nn.Module): Head applied to the encoder output.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        test_loader: Iterable yielding ``(images, labels)`` evaluation batches.
        device (torch.device): Device the models and batches are moved to.
        epochs (int): Number of passes over ``train_loader``.
        lr (float): Learning rate for Adam.
        top_n (int): ``n`` used for the reported top-n accuracy.
    """
    encoder.to(device)
    classifier.to(device)

    # Loss function and a single optimizer over both modules' parameters.
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(
        list(encoder.parameters()) + list(classifier.parameters()), lr=lr
    )

    for epoch in range(epochs):
        encoder.train()
        classifier.train()

        running_loss = 0.0
        total_samples = 0
        total_correct = 0

        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            features = encoder(images)
            outputs = classifier(features)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track loss and accuracy. The loss is a batch mean, so weight it
            # by the batch size to get a correct per-sample average.
            running_loss += loss.item() * labels.size(0)
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total_samples += labels.size(0)

        # Guarded so an empty loader reports 0.0 instead of crashing.
        train_loss = _safe_ratio(running_loss, total_samples)
        train_acc = _safe_ratio(total_correct, total_samples)

        # Evaluation
        encoder.eval()
        classifier.eval()

        total_top1_acc = 0
        total_topn_acc = 0
        total_samples = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                features = encoder(images)
                outputs = classifier(features)

                # Per-batch accuracies are fractions; re-weight by batch size.
                total_top1_acc += (
                    eval_top_n_accuracy(outputs, labels, n=1) * labels.size(0)
                )
                total_topn_acc += (
                    eval_top_n_accuracy(outputs, labels, n=top_n) * labels.size(0)
                )
                total_samples += labels.size(0)

        top1_acc = _safe_ratio(total_top1_acc, total_samples)
        topn_acc = _safe_ratio(total_topn_acc, total_samples)

        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Top-1 Acc: {top1_acc:.4f}, Top-{top_n} Acc: {topn_acc:.4f}")


if __name__ == "__main__":
    # Path to the Tiny ImageNet dataset
    data_dir = "./datasets/tiny-imagenet-200/"
    batch_size = 256
    epochs = 20
    learning_rate = 1e-4
    top_n = 5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the dataset
    train_loader, test_loader = load_tiny_imagenet(data_dir, batch_size)

    # Build the models
    encoder = Encoder()
    classifier = Classifier(num_classes=200)

    # Train
    train_model(encoder, classifier, train_loader, test_loader, device,
                epochs, learning_rate, top_n)
Leave a Comment